From 1486f4b4c4e1c667ebe11c8127c26ad45d969706 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Tue, 13 Feb 2024 09:57:22 -0800 Subject: [PATCH 01/27] refactor message --- dotnet/.editorconfig | 3 - dotnet/Directory.Build.props | 2 + .../AutoGen/Core/Agent/ConversableAgent.cs | 2 +- .../AutoGen/Core/Agent/GroupChatManager.cs | 2 +- .../src/AutoGen/Core/Agent/MiddlewareAgent.cs | 2 +- dotnet/src/AutoGen/Core/IAgent.cs | 6 +- dotnet/src/AutoGen/Core/Message.cs | 146 +++++++++++++++++- .../OpenAI/Extension/MessageExtension.cs | 94 +++++++++++ dotnet/src/AutoGen/OpenAI/GPTAgent.cs | 8 +- 9 files changed, 251 insertions(+), 14 deletions(-) diff --git a/dotnet/.editorconfig b/dotnet/.editorconfig index 3a4f76bcc18..4da1adc5de6 100644 --- a/dotnet/.editorconfig +++ b/dotnet/.editorconfig @@ -167,9 +167,6 @@ dotnet_diagnostic.IDE0005.severity = error # IDE0069: Remove unused local variable dotnet_diagnostic.IDE0069.severity = error -# IDE0060: Remove unused parameter -dotnet_diagnostic.IDE0060.severity = warning - # disable CS1573: Parameter has no matching param tag in the XML comment for dotnet_diagnostic.CS1573.severity = none diff --git a/dotnet/Directory.Build.props b/dotnet/Directory.Build.props index 141e1c23adc..03a11d92c23 100644 --- a/dotnet/Directory.Build.props +++ b/dotnet/Directory.Build.props @@ -13,6 +13,8 @@ $(NoWarn);$(CSNoWarn);NU5104 true false + true + true diff --git a/dotnet/src/AutoGen/Core/Agent/ConversableAgent.cs b/dotnet/src/AutoGen/Core/Agent/ConversableAgent.cs index 0995667eb66..4683add5897 100644 --- a/dotnet/src/AutoGen/Core/Agent/ConversableAgent.cs +++ b/dotnet/src/AutoGen/Core/Agent/ConversableAgent.cs @@ -106,7 +106,7 @@ public enum HumanInputMode return agent; } - public string? Name { get; } + public string Name { get; } public Func, CancellationToken, Task>? 
IsTermination { get; } diff --git a/dotnet/src/AutoGen/Core/Agent/GroupChatManager.cs b/dotnet/src/AutoGen/Core/Agent/GroupChatManager.cs index cae2b713ac7..da38e253491 100644 --- a/dotnet/src/AutoGen/Core/Agent/GroupChatManager.cs +++ b/dotnet/src/AutoGen/Core/Agent/GroupChatManager.cs @@ -15,7 +15,7 @@ public GroupChatManager(IGroupChat groupChat) { GroupChat = groupChat; } - public string? Name => throw new ArgumentException("GroupChatManager does not have a name"); + public string Name => throw new ArgumentException("GroupChatManager does not have a name"); public IEnumerable? Messages { get; private set; } diff --git a/dotnet/src/AutoGen/Core/Agent/MiddlewareAgent.cs b/dotnet/src/AutoGen/Core/Agent/MiddlewareAgent.cs index 0627badfdca..bb2eeea1619 100644 --- a/dotnet/src/AutoGen/Core/Agent/MiddlewareAgent.cs +++ b/dotnet/src/AutoGen/Core/Agent/MiddlewareAgent.cs @@ -45,7 +45,7 @@ public MiddlewareAgent(IAgent innerAgent, string? name = null) this.Name = name ?? innerAgent.Name; } - public string? Name { get; } + public string Name { get; } public Task GenerateReplyAsync( IEnumerable messages, diff --git a/dotnet/src/AutoGen/Core/IAgent.cs b/dotnet/src/AutoGen/Core/IAgent.cs index 0cf37f03012..3e4ab709236 100644 --- a/dotnet/src/AutoGen/Core/IAgent.cs +++ b/dotnet/src/AutoGen/Core/IAgent.cs @@ -10,7 +10,7 @@ namespace AutoGen; public interface IAgent { - public string? Name { get; } + public string Name { get; } /// /// Generate reply @@ -28,8 +28,8 @@ public interface IAgent /// public interface IStreamingReplyAgent : IAgent { - public Task> GenerateReplyStreamingAsync( - IEnumerable messages, + public Task> GenerateReplyStreamingAsync( + IEnumerable messages, GenerateReplyOptions? 
options = null, CancellationToken cancellationToken = default); } diff --git a/dotnet/src/AutoGen/Core/Message.cs b/dotnet/src/AutoGen/Core/Message.cs index 5457983135d..f3639077811 100644 --- a/dotnet/src/AutoGen/Core/Message.cs +++ b/dotnet/src/AutoGen/Core/Message.cs @@ -1,12 +1,156 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Message.cs +using System; using System.Collections.Generic; using Azure.AI.OpenAI; namespace AutoGen; -public class Message +public interface IMessage +{ + string? From { get; set; } +} + +internal class TextMessage : IMessage +{ + public TextMessage(Role role, string content, string? from = null) + { + this.Content = content; + this.Role = role; + this.From = from; + } + + public Role Role { get; set; } + + public string Content { get; } + + public string? From { get; set; } + + public bool Equals(IMessage other) + { + throw new NotImplementedException(); + } +} + +internal class ImageMessage : IMessage +{ + public ImageMessage(Role role, string url, string? from = null) + { + this.Role = role; + this.From = from; + this.Url = url; + } + + public Role Role { get; set; } + + public string Url { get; set; } + + public string? From { get; set; } + + public bool Equals(IMessage other) + { + throw new NotImplementedException(); + } +} + +internal class ToolCallMessage : IMessage +{ + public ToolCallMessage(Role role, string functionName, string functionArgs, string? from = null) + { + this.Role = role; + this.From = from; + this.FunctionName = functionName; + this.FunctionArguments = functionArgs; + } + + public Role Role { get; set; } + + public string FunctionName { get; set; } + + public string FunctionArguments { get; set; } + + public string? From { get; set; } + + public bool Equals(IMessage other) + { + throw new NotImplementedException(); + } +} + +internal class ToolCallResultMessage : IMessage +{ + public ToolCallResultMessage(Role role, string result, ToolCallMessage toolCallMessage, string? 
from = null) + { + this.Role = role; + this.From = from; + this.Result = result; + this.ToolCallMessage = toolCallMessage; + } + + public Role Role { get; set; } + + /// + /// The original tool call message + /// + public ToolCallMessage ToolCallMessage { get; set; } + + /// + /// The result from the tool call + /// + public string Result { get; set; } + + public string? From { get; set; } + + public bool Equals(IMessage other) + { + throw new NotImplementedException(); + } +} + +internal class AggregateMessage : IMessage +{ + public AggregateMessage(IList messages, string? from = null) + { + this.From = from; + this.Messages = messages; + this.Validate(); + } + + public IList Messages { get; set; } + + public string? From { get; set; } + + private void Validate() + { + // the from property of all messages should be the same with the from property of the aggregate message + foreach (var message in this.Messages) + { + if (message.From != this.From) + { + var reason = $"The from property of the message {message} is different from the from property of the aggregate message {this}"; + throw new ArgumentException($"Invalid aggregate message {reason}"); + } + } + + // no nested aggregate message + foreach (var message in this.Messages) + { + if (message is AggregateMessage) + { + var reason = $"The message {message} is an aggregate message"; + throw new ArgumentException("Invalid aggregate message " + reason); + } + } + } + + public bool Equals(IMessage other) + { + throw new NotImplementedException(); + } +} + +public class Message : IMessage { public Message( Role role, diff --git a/dotnet/src/AutoGen/OpenAI/Extension/MessageExtension.cs b/dotnet/src/AutoGen/OpenAI/Extension/MessageExtension.cs index d7e5a6c4409..9cb957792b5 100644 --- a/dotnet/src/AutoGen/OpenAI/Extension/MessageExtension.cs +++ b/dotnet/src/AutoGen/OpenAI/Extension/MessageExtension.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Linq; using Azure.AI.OpenAI; namespace 
AutoGen.OpenAI; @@ -162,4 +163,97 @@ public static ChatRequestFunctionMessage ToChatRequestFunctionMessage(this Messa return functionMessage; } + + public static IEnumerable ToOpenAIChatRequestMessage(this IAgent agent, IMessage message) + { + if (message.From != agent.Name) + { + if (message is TextMessage textMessage) + { + if (textMessage.Role == Role.System) + { + return [new ChatRequestSystemMessage(textMessage.Content)]; + } + else + { + return [new ChatRequestUserMessage(textMessage.Content)]; + } + } + else if (message is ToolCallMessage) + { + throw new ArgumentException($"ToolCallMessage is not supported when message.From is not the same with agent"); + } + else if (message is ToolCallResultMessage toolCallResult) + { + return [new ChatRequestToolMessage(toolCallResult.Result, toolCallResult.ToolCallMessage.FunctionName)]; + } + else if (message is AggregateMessage aggregateMessage) + { + // if aggreate message contains a list of tool call result message, then it is a parallel tool call message + if (aggregateMessage.Messages.All(m => m is ToolCallResultMessage)) + { + return aggregateMessage.Messages.Select(message => new ChatRequestToolMessage((message as ToolCallResultMessage)!.Result, (message as ToolCallResultMessage)!.ToolCallMessage.FunctionName)); + } + + // otherwise, it's a multi-modal message + IEnumerable messageContent = aggregateMessage.Messages.Select(m => + { + return m switch + { + TextMessage textMessage => new ChatMessageTextContentItem(textMessage.Content), + ImageMessage imageMessage => new ChatMessageImageContentItem(new Uri(imageMessage.Url)), + _ => throw new ArgumentException($"Unknown message type: {m.GetType()}") + }; + }); + + return [new ChatRequestUserMessage(messageContent)]; + } + else + { + throw new ArgumentException($"Unknown message type: {message.GetType()}"); + } + } + else + { + if (message is TextMessage textMessage) + { + if (textMessage.Role == Role.System) + { + throw new ArgumentException("System message is not 
supported when message.From is the same with agent"); + } + + return [new ChatRequestAssistantMessage(textMessage.Content)]; + } + else if (message is ToolCallMessage toolCallMessage) + { + // single tool call message + var assistantMessage = new ChatRequestAssistantMessage(string.Empty); + assistantMessage.ToolCalls.Add(new ChatCompletionsFunctionToolCall(toolCallMessage.FunctionName, toolCallMessage.FunctionName, toolCallMessage.FunctionArguments)); + + return [assistantMessage]; + } + else if (message is AggregateMessage aggregateMessage) + { + // parallel tool call messages + var assistantMessage = new ChatRequestAssistantMessage(string.Empty); + foreach (var m in aggregateMessage.Messages) + { + if (m is ToolCallMessage toolCall) + { + assistantMessage.ToolCalls.Add(new ChatCompletionsFunctionToolCall(toolCall.FunctionName, toolCall.FunctionName, toolCall.FunctionArguments)); + } + else + { + throw new ArgumentException($"Unknown message type: {m.GetType()}"); + } + } + + return [assistantMessage]; + } + else + { + throw new ArgumentException($"Unknown message type: {message.GetType()}"); + } + } + } } diff --git a/dotnet/src/AutoGen/OpenAI/GPTAgent.cs b/dotnet/src/AutoGen/OpenAI/GPTAgent.cs index b54fafba360..b5fb55256f2 100644 --- a/dotnet/src/AutoGen/OpenAI/GPTAgent.cs +++ b/dotnet/src/AutoGen/OpenAI/GPTAgent.cs @@ -72,7 +72,7 @@ public class GPTAgent : IStreamingReplyAgent this.functionMap = functionMap; } - public string? Name { get; } + public string Name { get; } public async Task GenerateReplyAsync( IEnumerable messages, @@ -86,8 +86,8 @@ public class GPTAgent : IStreamingReplyAgent return await this.PostProcessMessage(oaiMessage); } - public async Task> GenerateReplyStreamingAsync( - IEnumerable messages, + public async Task> GenerateReplyStreamingAsync( + IEnumerable messages, GenerateReplyOptions? 
options = null, CancellationToken cancellationToken = default) { @@ -180,7 +180,7 @@ await foreach (var chunk in response) } } - private ChatCompletionsOptions CreateChatCompletionsOptions(GenerateReplyOptions? options, IEnumerable messages) + private ChatCompletionsOptions CreateChatCompletionsOptions(GenerateReplyOptions? options, IEnumerable messages) { var oaiMessages = this.ProcessMessages(messages); var settings = new ChatCompletionsOptions(this.modelName, oaiMessages) From f876431cf27df60054c73f453545c3755f0a0975 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Tue, 13 Feb 2024 18:26:21 -0800 Subject: [PATCH 02/27] refactor over IMessage --- dotnet/src/AutoGen.LMStudio/LMStudioAgent.cs | 2 +- dotnet/src/AutoGen/Core/Message.cs | 193 ----------- dotnet/src/AutoGen/Core/MessageEnvelope.cs | 314 ++++++++++++++++++ .../OpenAI/Extension/MessageExtension.cs | 125 ++++++- dotnet/src/AutoGen/OpenAI/GPTAgent.cs | 86 +---- ...MessageTests.BasicMessageTest.approved.txt | 195 +++++++++++ .../test/AutoGen.Tests/OpenAIMessageTests.cs | 170 ++++++++++ dotnet/test/AutoGen.Tests/SingleAgentTest.cs | 39 ++- 8 files changed, 823 insertions(+), 301 deletions(-) delete mode 100644 dotnet/src/AutoGen/Core/Message.cs create mode 100644 dotnet/src/AutoGen/Core/MessageEnvelope.cs create mode 100644 dotnet/test/AutoGen.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt create mode 100644 dotnet/test/AutoGen.Tests/OpenAIMessageTests.cs diff --git a/dotnet/src/AutoGen.LMStudio/LMStudioAgent.cs b/dotnet/src/AutoGen.LMStudio/LMStudioAgent.cs index e438abc14cc..14c8be0def2 100644 --- a/dotnet/src/AutoGen.LMStudio/LMStudioAgent.cs +++ b/dotnet/src/AutoGen.LMStudio/LMStudioAgent.cs @@ -42,7 +42,7 @@ public class LMStudioAgent : IAgent functionMap: functionMap); } - public string? 
Name => innerAgent.Name; + public string Name => innerAgent.Name; public Task GenerateReplyAsync( IEnumerable messages, diff --git a/dotnet/src/AutoGen/Core/Message.cs b/dotnet/src/AutoGen/Core/Message.cs deleted file mode 100644 index f3639077811..00000000000 --- a/dotnet/src/AutoGen/Core/Message.cs +++ /dev/null @@ -1,193 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Message.cs - -using System; -using System.Collections.Generic; -using Azure.AI.OpenAI; - -namespace AutoGen; - -public interface IMessage -{ - string? From { get; set; } -} - -internal class TextMessage : IMessage -{ - public TextMessage(Role role, string content, string? from = null) - { - this.Content = content; - this.Role = role; - this.From = from; - } - - public Role Role { get; set; } - - public string Content { get; } - - public string? From { get; set; } - - public bool Equals(IMessage other) - { - throw new NotImplementedException(); - } -} - -internal class ImageMessage : IMessage -{ - public ImageMessage(Role role, string url, string? from = null) - { - this.Role = role; - this.From = from; - this.Url = url; - } - - public Role Role { get; set; } - - public string Url { get; set; } - - public string? From { get; set; } - - public bool Equals(IMessage other) - { - throw new NotImplementedException(); - } -} - -internal class ToolCallMessage : IMessage -{ - public ToolCallMessage(Role role, string functionName, string functionArgs, string? from = null) - { - this.Role = role; - this.From = from; - this.FunctionName = functionName; - this.FunctionArguments = functionArgs; - } - - public Role Role { get; set; } - - public string FunctionName { get; set; } - - public string FunctionArguments { get; set; } - - public string? 
From { get; set; } - - public bool Equals(IMessage other) - { - throw new NotImplementedException(); - } -} - -internal class ToolCallResultMessage : IMessage -{ - public ToolCallResultMessage(Role role, string result, ToolCallMessage toolCallMessage, string? from = null) - { - this.Role = role; - this.From = from; - this.Result = result; - this.ToolCallMessage = toolCallMessage; - } - - public Role Role { get; set; } - - /// - /// The original tool call message - /// - public ToolCallMessage ToolCallMessage { get; set; } - - /// - /// The result from the tool call - /// - public string Result { get; set; } - - public string? From { get; set; } - - public bool Equals(IMessage other) - { - throw new NotImplementedException(); - } -} - -internal class AggregateMessage : IMessage -{ - public AggregateMessage(IList messages, string? from = null) - { - this.From = from; - this.Messages = messages; - this.Validate(); - } - - public IList Messages { get; set; } - - public string? From { get; set; } - - private void Validate() - { - // the from property of all messages should be the same with the from property of the aggregate message - foreach (var message in this.Messages) - { - if (message.From != this.From) - { - var reason = $"The from property of the message {message} is different from the from property of the aggregate message {this}"; - throw new ArgumentException($"Invalid aggregate message {reason}"); - } - } - - // no nested aggregate message - foreach (var message in this.Messages) - { - if (message is AggregateMessage) - { - var reason = $"The message {message} is an aggregate message"; - throw new ArgumentException("Invalid aggregate message " + reason); - } - } - } - - public bool Equals(IMessage other) - { - throw new NotImplementedException(); - } -} - -public class Message : IMessage -{ - public Message( - Role role, - string? content, - string? from = null, - FunctionCall? 
functionCall = null) - { - this.Role = role; - this.Content = content; - this.From = from; - this.FunctionName = functionCall?.Name; - this.FunctionArguments = functionCall?.Arguments; - } - - public Message(Message other) - : this(other.Role, other.Content, other.From) - { - this.FunctionName = other.FunctionName; - this.FunctionArguments = other.FunctionArguments; - this.Value = other.Value; - this.Metadata = other.Metadata; - } - - public Role Role { get; set; } - - public string? Content { get; set; } - - public string? From { get; set; } - - public string? FunctionName { get; set; } - - public string? FunctionArguments { get; set; } - - /// - /// raw message - /// - public object? Value { get; set; } - - public IList> Metadata { get; set; } = new List>(); -} diff --git a/dotnet/src/AutoGen/Core/MessageEnvelope.cs b/dotnet/src/AutoGen/Core/MessageEnvelope.cs new file mode 100644 index 00000000000..b04bae44f44 --- /dev/null +++ b/dotnet/src/AutoGen/Core/MessageEnvelope.cs @@ -0,0 +1,314 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Message.cs + +using System; +using System.Collections.Generic; +using Azure.AI.OpenAI; + +namespace AutoGen; + +public interface IMessage +{ + string? From { get; set; } +} + +public interface IMessage : IMessage +{ + T Content { get; } +} + +internal class MessageEnvelope : IMessage +{ + public MessageEnvelope(T content, string? from = null, IDictionary? metadata = null) + { + this.Content = content; + this.From = from; + this.Metadata = metadata ?? new Dictionary(); + } + + public T Content { get; } + + public string? From { get; set; } + + public IDictionary Metadata { get; set; } +} + +public class TextMessage : IMessage +{ + public TextMessage(Role role, string content, string? from = null) + { + this.Content = content; + this.Role = role; + this.From = from; + } + + public Role Role { get; set; } + + public string Content { get; } + + public string? 
From { get; set; } + + public override string ToString() + { + return $"TextMessage({this.Role}, {this.Content}, {this.From})"; + } +} + +public class ImageMessage : IMessage +{ + public ImageMessage(Role role, string url, string? from = null) + { + this.Role = role; + this.From = from; + this.Url = url; + } + + public Role Role { get; set; } + + public string Url { get; set; } + + public string? From { get; set; } + + public override string ToString() + { + return $"ImageMessage({this.Role}, {this.Url}, {this.From})"; + } +} + +public class ToolCallMessage : IMessage +{ + public ToolCallMessage(string functionName, string functionArgs, string? from = null) + { + this.From = from; + this.FunctionName = functionName; + this.FunctionArguments = functionArgs; + } + + public string FunctionName { get; set; } + + public string FunctionArguments { get; set; } + + public string? From { get; set; } + + public override string ToString() + { + return $"ToolCallMessage({this.FunctionName}, {this.FunctionArguments}, {this.From})"; + } +} + +public class ToolCallResultMessage : IMessage +{ + public ToolCallResultMessage(string result, ToolCallMessage toolCallMessage, string? from = null) + { + this.From = from; + this.Result = result; + this.ToolCallMessage = toolCallMessage; + } + + /// + /// The original tool call message + /// + public ToolCallMessage ToolCallMessage { get; set; } + + /// + /// The result from the tool call + /// + public string Result { get; set; } + + public string? From { get; set; } + + public override string ToString() + { + return $"ToolCallResultMessage({this.Result}, {this.ToolCallMessage}, {this.From})"; + } +} + +public class MultiModalMessage : IMessage +{ + public MultiModalMessage(IEnumerable content, string? from = null) + { + this.Content = content; + this.From = from; + this.Validate(); + } + + public Role Role { get; set; } + + public IEnumerable Content { get; set; } + + public string? 
From { get; set; } + + private void Validate() + { + foreach (var message in this.Content) + { + if (message.From != this.From) + { + var reason = $"The from property of the message {message} is different from the from property of the aggregate message {this}"; + throw new ArgumentException($"Invalid aggregate message {reason}"); + } + } + + // all message must be either text or image + foreach (var message in this.Content) + { + if (message is not TextMessage && message is not ImageMessage) + { + var reason = $"The message {message} is not a text or image message"; + throw new ArgumentException($"Invalid aggregate message {reason}"); + } + } + } + + public override string ToString() + { + var stringBuilder = new System.Text.StringBuilder(); + stringBuilder.Append($"MultiModalMessage({this.Role}, {this.From})"); + foreach (var message in this.Content) + { + stringBuilder.Append($"\n\t{message}"); + } + + return stringBuilder.ToString(); + } +} + +public class ParallelToolCallResultMessage : IMessage +{ + public ParallelToolCallResultMessage(IEnumerable toolCallResult, string? from = null) + { + this.ToolCallResult = toolCallResult; + this.From = from; + this.Validate(); + } + + public IEnumerable ToolCallResult { get; set; } + + public string? 
From { get; set; } + + public bool Equals(IMessage other) + { + throw new NotImplementedException(); + } + + private void Validate() + { + // the from property of all messages should be the same with the from property of the aggregate message + foreach (var message in this.ToolCallResult) + { + if (message.From != this.From) + { + var reason = $"The from property of the message {message} is different from the from property of the aggregate message {this}"; + throw new ArgumentException($"Invalid aggregate message {reason}"); + } + } + } + + public override string ToString() + { + var stringBuilder = new System.Text.StringBuilder(); + stringBuilder.Append($"ParallelToolCallResultMessage({this.From})"); + foreach (var message in this.ToolCallResult) + { + stringBuilder.Append($"\n\t{message}"); + } + + return stringBuilder.ToString(); + } +} + +public class AggregateMessage : IMessage +{ + public AggregateMessage(IEnumerable messages, string? from = null) + { + this.From = from; + this.Messages = messages; + this.Validate(); + } + + public IEnumerable Messages { get; set; } + + public string? 
From { get; set; } + + private void Validate() + { + // the from property of all messages should be the same with the from property of the aggregate message + foreach (var message in this.Messages) + { + if (message.From != this.From) + { + var reason = $"The from property of the message {message} is different from the from property of the aggregate message {this}"; + throw new ArgumentException($"Invalid aggregate message {reason}"); + } + } + + // no nested aggregate message + foreach (var message in this.Messages) + { + if (message is AggregateMessage) + { + var reason = $"The message {message} is an aggregate message"; + throw new ArgumentException("Invalid aggregate message " + reason); + } + } + } + + public override string ToString() + { + var stringBuilder = new System.Text.StringBuilder(); + stringBuilder.Append($"AggregateMessage({this.From})"); + foreach (var message in this.Messages) + { + stringBuilder.Append($"\n\t{message}"); + } + + return stringBuilder.ToString(); + } +} + +public class Message : IMessage +{ + public Message( + Role role, + string? content, + string? from = null, + FunctionCall? functionCall = null) + { + this.Role = role; + this.Content = content; + this.From = from; + this.FunctionName = functionCall?.Name; + this.FunctionArguments = functionCall?.Arguments; + } + + public Message(Message other) + : this(other.Role, other.Content, other.From) + { + this.FunctionName = other.FunctionName; + this.FunctionArguments = other.FunctionArguments; + this.Value = other.Value; + this.Metadata = other.Metadata; + } + + public Role Role { get; set; } + + public string? Content { get; set; } + + public string? From { get; set; } + + public string? FunctionName { get; set; } + + public string? FunctionArguments { get; set; } + + /// + /// raw message + /// + public object? 
Value { get; set; } + + public IList> Metadata { get; set; } = new List>(); + + public override string ToString() + { + return $"Message({this.Role}, {this.Content}, {this.From}, {this.FunctionName}, {this.FunctionArguments})"; + } +} diff --git a/dotnet/src/AutoGen/OpenAI/Extension/MessageExtension.cs b/dotnet/src/AutoGen/OpenAI/Extension/MessageExtension.cs index 9cb957792b5..ba4fc2d6b4e 100644 --- a/dotnet/src/AutoGen/OpenAI/Extension/MessageExtension.cs +++ b/dotnet/src/AutoGen/OpenAI/Extension/MessageExtension.cs @@ -164,39 +164,53 @@ public static ChatRequestFunctionMessage ToChatRequestFunctionMessage(this Messa return functionMessage; } - public static IEnumerable ToOpenAIChatRequestMessage(this IAgent agent, IMessage message) + private static IMessage ToMessageEnvelope(this TMessage msg, string? from) + where TMessage : ChatRequestMessage { + return new MessageEnvelope(msg, from); + } + + public static IEnumerable> ToOpenAIChatRequestMessage(this IAgent agent, IMessage message) + { + if (message is IMessage oaiMessage) + { + return [oaiMessage]; + } + if (message.From != agent.Name) { if (message is TextMessage textMessage) { if (textMessage.Role == Role.System) { - return [new ChatRequestSystemMessage(textMessage.Content)]; + var msg = new ChatRequestSystemMessage(textMessage.Content); + return [msg.ToMessageEnvelope(message.From)]; } else { - return [new ChatRequestUserMessage(textMessage.Content)]; + var msg = new ChatRequestUserMessage(textMessage.Content); + return [msg.ToMessageEnvelope(message.From)]; } } + else if (message is ImageMessage imageMessage) + { + // multi-modal + var msg = new ChatRequestUserMessage(new ChatMessageImageContentItem(new Uri(imageMessage.Url))); + + return [msg.ToMessageEnvelope(message.From)]; + } else if (message is ToolCallMessage) { throw new ArgumentException($"ToolCallMessage is not supported when message.From is not the same with agent"); } else if (message is ToolCallResultMessage toolCallResult) { - return [new 
ChatRequestToolMessage(toolCallResult.Result, toolCallResult.ToolCallMessage.FunctionName)]; + var msg = new ChatRequestToolMessage(toolCallResult.Result, toolCallResult.ToolCallMessage.FunctionName); + return [msg.ToMessageEnvelope(message.From)]; } - else if (message is AggregateMessage aggregateMessage) + else if (message is MultiModalMessage multiModalMessage) { - // if aggreate message contains a list of tool call result message, then it is a parallel tool call message - if (aggregateMessage.Messages.All(m => m is ToolCallResultMessage)) - { - return aggregateMessage.Messages.Select(message => new ChatRequestToolMessage((message as ToolCallResultMessage)!.Result, (message as ToolCallResultMessage)!.ToolCallMessage.FunctionName)); - } - - // otherwise, it's a multi-modal message - IEnumerable messageContent = aggregateMessage.Messages.Select(m => + var messageContent = multiModalMessage.Content.Select(m => { return m switch { @@ -206,7 +220,46 @@ public static IEnumerable ToOpenAIChatRequestMessage(this IA }; }); - return [new ChatRequestUserMessage(messageContent)]; + var msg = new ChatRequestUserMessage(messageContent); + return [msg.ToMessageEnvelope(message.From)]; + } + else if (message is ParallelToolCallResultMessage parallelToolCallResultMessage) + { + return parallelToolCallResultMessage.ToolCallResult.Select(m => + { + var msg = new ChatRequestToolMessage(m.Result, m.ToolCallMessage.FunctionName); + + return msg.ToMessageEnvelope(message.From); + }); + } + else if (message is Message msg) + { + if (msg.Role == Role.System) + { + var systemMessage = new ChatRequestSystemMessage(msg.Content ?? 
string.Empty); + return [systemMessage.ToMessageEnvelope(message.From)]; + } + else if (msg.FunctionName is null && msg.FunctionArguments is null) + { + var userMessage = msg.ToChatRequestUserMessage(); + return [userMessage.ToMessageEnvelope(message.From)]; + } + else if (msg.FunctionName is not null && msg.FunctionArguments is not null && msg.Content is not null) + { + if (msg.Role == Role.Function) + { + return [new ChatRequestFunctionMessage(msg.FunctionName, msg.Content).ToMessageEnvelope(message.From)]; + } + else + { + return [new ChatRequestUserMessage(msg.Content).ToMessageEnvelope(message.From)]; + } + } + else + { + var userMessage = new ChatRequestUserMessage(msg.Content ?? throw new ArgumentException("Content is null")); + return [userMessage.ToMessageEnvelope(message.From)]; + } } else { @@ -222,7 +275,8 @@ public static IEnumerable ToOpenAIChatRequestMessage(this IA throw new ArgumentException("System message is not supported when message.From is the same with agent"); } - return [new ChatRequestAssistantMessage(textMessage.Content)]; + + return [new ChatRequestAssistantMessage(textMessage.Content).ToMessageEnvelope(message.From)]; } else if (message is ToolCallMessage toolCallMessage) { @@ -230,7 +284,7 @@ public static IEnumerable ToOpenAIChatRequestMessage(this IA var assistantMessage = new ChatRequestAssistantMessage(string.Empty); assistantMessage.ToolCalls.Add(new ChatCompletionsFunctionToolCall(toolCallMessage.FunctionName, toolCallMessage.FunctionName, toolCallMessage.FunctionArguments)); - return [assistantMessage]; + return [assistantMessage.ToMessageEnvelope(message.From)]; } else if (message is AggregateMessage aggregateMessage) { @@ -248,7 +302,34 @@ public static IEnumerable ToOpenAIChatRequestMessage(this IA } } - return [assistantMessage]; + return [assistantMessage.ToMessageEnvelope(message.From)]; + } + else if (message is Message msg) + { + if (msg.FunctionArguments is not null && msg.FunctionName is not null && msg.Content is not 
null) + { + var assistantMessage = new ChatRequestAssistantMessage(msg.Content); + assistantMessage.FunctionCall = new FunctionCall(msg.FunctionName, msg.FunctionArguments); + var functionCallMessage = new ChatRequestFunctionMessage(msg.FunctionName, msg.Content); + return [assistantMessage.ToMessageEnvelope(message.From), functionCallMessage.ToMessageEnvelope(message.From)]; + } + else + { + if (msg.Role == Role.Function) + { + return [new ChatRequestFunctionMessage(msg.FunctionName!, msg.Content!).ToMessageEnvelope(message.From)]; + } + else + { + var assistantMessage = new ChatRequestAssistantMessage(msg.Content!); + if (msg.FunctionName is not null && msg.FunctionArguments is not null) + { + assistantMessage.FunctionCall = new FunctionCall(msg.FunctionName, msg.FunctionArguments); + } + + return [assistantMessage.ToMessageEnvelope(message.From)]; + } + } } else { @@ -256,4 +337,14 @@ public static IEnumerable ToOpenAIChatRequestMessage(this IA } } } + + public static IEnumerable ToAutoGenMessages(this IAgent agent, IEnumerable> openaiMessages) + { + throw new NotImplementedException(); + } + + public static IMessage ToAutoGenMessage(ChatRequestMessage openaiMessage, string? from = null) + { + throw new NotImplementedException(); + } } diff --git a/dotnet/src/AutoGen/OpenAI/GPTAgent.cs b/dotnet/src/AutoGen/OpenAI/GPTAgent.cs index b5fb55256f2..34d482703da 100644 --- a/dotnet/src/AutoGen/OpenAI/GPTAgent.cs +++ b/dotnet/src/AutoGen/OpenAI/GPTAgent.cs @@ -96,7 +96,7 @@ public class GPTAgent : IStreamingReplyAgent return this.ProcessResponse(response); } - private async IAsyncEnumerable ProcessResponse(StreamingResponse response) + private async IAsyncEnumerable ProcessResponse(StreamingResponse response) { var content = string.Empty; string? 
functionName = default; @@ -136,8 +136,7 @@ await foreach (var chunk in response) // in this case we yield the message if (content is not null && functionName is null) { - var msg = new Message(Role.Assistant, content, from: this.Name); - msg.Metadata.Add(new KeyValuePair(CHUNK_KEY, chunk!)); + var msg = new TextMessage(Role.Assistant, content, from: this.Name); yield return msg; continue; @@ -147,12 +146,8 @@ await foreach (var chunk in response) // in this case, we yield the message once after function name is available and function args has been updated if (functionName is not null && functionArguments is not null) { - var msg = new Message(Role.Assistant, null, from: this.Name) - { - FunctionName = functionName, - FunctionArguments = functionArguments, - }; - msg.Metadata.Add(new KeyValuePair(CHUNK_KEY, chunk!)); + var msg = new ToolCallMessage(functionName, functionArguments, from: this.Name); + yield return msg; if (functionMap is not null && chunk?.FinishReason is not null && chunk.FinishReason == CompletionsFinishReason.FunctionCall) { @@ -160,19 +155,13 @@ await foreach (var chunk in response) if (this.functionMap.TryGetValue(functionName, out var func)) { var result = await func(functionArguments); - msg.Content = result; + yield return new ToolCallResultMessage(result, msg, from: this.Name); } else { var errorMessage = $"Function {functionName} is not available. 
Available functions are: {string.Join(", ", this.functionMap.Select(f => f.Key))}"; - msg.Content = errorMessage; + yield return new ToolCallResultMessage(errorMessage, msg, from: this.Name); } - - yield return msg; - } - else - { - yield return msg; } continue; @@ -211,68 +200,17 @@ private ChatCompletionsOptions CreateChatCompletionsOptions(GenerateReplyOptions } - private IEnumerable ProcessMessages(IEnumerable messages) + private IEnumerable ProcessMessages(IEnumerable messages) { // add system message if there's no system message in messages - if (!messages.Any(m => m.Role == Role.System)) + var openAIMessages = messages.SelectMany(m => this.ToOpenAIChatRequestMessage(m)) + .Select(m => m.Content) ?? []; + if (!openAIMessages.Any(m => m is ChatRequestSystemMessage)) { - messages = new[] { new Message(Role.System, _systemMessage) }.Concat(messages); + openAIMessages = new[] { new ChatRequestSystemMessage(_systemMessage) }.Concat(openAIMessages); } - var i = 0; - foreach (var message in messages) - { - if (message.Role == Role.System || message.From is null) - { - if (message.Role == Role.System) - { - // add as system message - yield return message.ToChatRequestSystemMessage(); - } - else - { - // add as user message - yield return message.ToChatRequestUserMessage(); - } - } - else if (message.From != this.Name) - { - if (message.Role == Role.Function) - { - yield return message.ToChatRequestFunctionMessage(); - } - else - { - yield return message.ToChatRequestUserMessage(); - } - } - else - { - if (message.FunctionArguments is string functionArguments && message.FunctionName is string functionName && message.Content is string) - { - i++; - - yield return message.ToChatRequestAssistantMessage(); - - var functionResultMessage = new ChatRequestFunctionMessage(functionName, message.Content); - - yield return message.ToChatRequestFunctionMessage(); - i++; - } - else - { - i++; - if (message.Role == Role.Function) - { - yield return 
message.ToChatRequestFunctionMessage(); - } - else - { - yield return message.ToChatRequestAssistantMessage(); - } - } - } - } + return openAIMessages; } private async Task PostProcessMessage(ChatResponseMessage oaiMessage) diff --git a/dotnet/test/AutoGen.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt b/dotnet/test/AutoGen.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt new file mode 100644 index 00000000000..b01d87819bd --- /dev/null +++ b/dotnet/test/AutoGen.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt @@ -0,0 +1,195 @@ +[ + { + "OriginalMessage": "TextMessage(system, You are a helpful AI assistant, )", + "ConvertedMessages": [ + { + "Role": "system", + "Content": "You are a helpful AI assistant" + } + ] + }, + { + "OriginalMessage": "TextMessage(user, Hello, user)", + "ConvertedMessages": [ + { + "Role": "user", + "Content": "Hello", + "MultiModaItem": null + } + ] + }, + { + "OriginalMessage": "TextMessage(assistant, How can I help you?, assistant)", + "ConvertedMessages": [ + { + "Role": "assistant", + "Content": "How can I help you?", + "TooCall": [], + "FunctionCallName": null, + "FunctionCallArguments": null + } + ] + }, + { + "OriginalMessage": "Message(system, You are a helpful AI assistant, , , )", + "ConvertedMessages": [ + { + "Role": "system", + "Content": "You are a helpful AI assistant" + } + ] + }, + { + "OriginalMessage": "Message(user, Hello, user, , )", + "ConvertedMessages": [ + { + "Role": "user", + "Content": "Hello", + "MultiModaItem": null + } + ] + }, + { + "OriginalMessage": "Message(assistant, How can I help you?, assistant, , )", + "ConvertedMessages": [ + { + "Role": "assistant", + "Content": "How can I help you?", + "TooCall": [], + "FunctionCallName": null, + "FunctionCallArguments": null + } + ] + }, + { + "OriginalMessage": "Message(function, result, user, , )", + "ConvertedMessages": [ + { + "Role": "user", + "Content": "result", + "MultiModaItem": null + } + ] + 
}, + { + "OriginalMessage": "Message(assistant, , assistant, functionName, functionArguments)", + "ConvertedMessages": [ + { + "Role": "assistant", + "Content": null, + "TooCall": [], + "FunctionCallName": "functionName", + "FunctionCallArguments": "functionArguments" + } + ] + }, + { + "OriginalMessage": "ImageMessage(user, https://example.com/image.png, user)", + "ConvertedMessages": [ + { + "Role": "user", + "Content": null, + "MultiModaItem": [ + { + "Type": "Image", + "ImageUrl": { + "Url": "https://example.com/image.png", + "Detail": null + } + } + ] + } + ] + }, + { + "OriginalMessage": "MultiModalMessage(, user)\n\tTextMessage(user, Hello, user)\n\tImageMessage(user, https://example.com/image.png, user)", + "ConvertedMessages": [ + { + "Role": "user", + "Content": null, + "MultiModaItem": [ + { + "Type": "Text", + "Text": "Hello" + }, + { + "Type": "Image", + "ImageUrl": { + "Url": "https://example.com/image.png", + "Detail": null + } + } + ] + } + ] + }, + { + "OriginalMessage": "ToolCallMessage(test, test, assistant)", + "ConvertedMessages": [ + { + "Role": "assistant", + "Content": "", + "TooCall": [ + { + "Type": "Function", + "Name": "test", + "Arguments": "test", + "Id": "test" + } + ], + "FunctionCallName": null, + "FunctionCallArguments": null + } + ] + }, + { + "OriginalMessage": "ToolCallResultMessage(result, ToolCallMessage(test, test, assistant), user)", + "ConvertedMessages": [ + { + "Role": "tool", + "Content": "result", + "ToolCallId": "test" + } + ] + }, + { + "OriginalMessage": "ParallelToolCallResultMessage(user)\n\tToolCallResultMessage(result, ToolCallMessage(test, test, assistant), user)\n\tToolCallResultMessage(result, ToolCallMessage(test, test, assistant), user)", + "ConvertedMessages": [ + { + "Role": "tool", + "Content": "result", + "ToolCallId": "test" + }, + { + "Role": "tool", + "Content": "result", + "ToolCallId": "test" + } + ] + }, + { + "OriginalMessage": "AggregateMessage(assistant)\n\tToolCallMessage(test, test, 
assistant)\n\tToolCallMessage(test, test, assistant)", + "ConvertedMessages": [ + { + "Role": "assistant", + "Content": "", + "TooCall": [ + { + "Type": "Function", + "Name": "test", + "Arguments": "test", + "Id": "test" + }, + { + "Type": "Function", + "Name": "test", + "Arguments": "test", + "Id": "test" + } + ], + "FunctionCallName": null, + "FunctionCallArguments": null + } + ] + } +] \ No newline at end of file diff --git a/dotnet/test/AutoGen.Tests/OpenAIMessageTests.cs b/dotnet/test/AutoGen.Tests/OpenAIMessageTests.cs new file mode 100644 index 00000000000..f9235effdfb --- /dev/null +++ b/dotnet/test/AutoGen.Tests/OpenAIMessageTests.cs @@ -0,0 +1,170 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// OpenAIMessageTests.cs + +using System.Collections.Generic; +using System.Linq; +using System.Text.Json; +using ApprovalTests; +using ApprovalTests.Namers; +using ApprovalTests.Reporters; +using AutoGen.OpenAI; +using Azure.AI.OpenAI; +using Xunit; + +namespace AutoGen.Tests; + +public class OpenAIMessageTests +{ + private readonly JsonSerializerOptions jsonSerializerOptions = new JsonSerializerOptions + { + WriteIndented = true, + IgnoreReadOnlyProperties = false, + }; + + [Fact] + [UseReporter(typeof(DiffReporter))] + [UseApprovalSubdirectory("ApprovalTests")] + public void BasicMessageTest() + { + IMessage[] messages = [ + new TextMessage(Role.System, "You are a helpful AI assistant"), + new TextMessage(Role.User, "Hello", "user"), + new TextMessage(Role.Assistant, "How can I help you?", from: "assistant"), + new Message(Role.System, "You are a helpful AI assistant"), + new Message(Role.User, "Hello", "user"), + new Message(Role.Assistant, "How can I help you?", from: "assistant"), + new Message(Role.Function, "result", "user"), + new Message(Role.Assistant, null, "assistant") + { + FunctionName = "functionName", + FunctionArguments = "functionArguments", + }, + new ImageMessage(Role.User, "https://example.com/image.png", "user"), + new 
MultiModalMessage( + [ + new TextMessage(Role.User, "Hello", "user"), + new ImageMessage(Role.User, "https://example.com/image.png", "user"), + ], "user"), + new ToolCallMessage("test", "test", "assistant"), + new ToolCallResultMessage("result", new ToolCallMessage("test", "test", "assistant"), "user"), + new ParallelToolCallResultMessage( + [ + new ToolCallResultMessage("result", new ToolCallMessage("test", "test", "assistant"), "user"), + new ToolCallResultMessage("result", new ToolCallMessage("test", "test", "assistant"), "user"), + ], "user"), + new AggregateMessage( + [ + new ToolCallMessage("test", "test", "assistant"), + new ToolCallMessage("test", "test", "assistant"), + ], "assistant"), + ]; + + var agent = new EchoAgent("assistant"); + + var oaiMessages = messages.Select(m => (m, agent.ToOpenAIChatRequestMessage(m).Select(m => m.Content))); + VerifyOAIMessages(oaiMessages); + } + + private void VerifyOAIMessages(IEnumerable<(IMessage, IEnumerable)> messages) + { + var jsonObjects = messages.Select(pair => + { + var (originalMessage, ms) = pair; + var objs = new List(); + foreach (var m in ms) + { + object? 
obj = null; + if (m is ChatRequestUserMessage userMessage) + { + obj = new + { + Role = userMessage.Role.ToString(), + Content = userMessage.Content, + MultiModaItem = userMessage.MultimodalContentItems?.Select(item => + { + return item switch + { + ChatMessageImageContentItem imageContentItem => new + { + Type = "Image", + ImageUrl = imageContentItem.ImageUrl, + } as object, + ChatMessageTextContentItem textContentItem => new + { + Type = "Text", + Text = textContentItem.Text, + } as object, + _ => throw new System.NotImplementedException(), + }; + }), + }; + } + + if (m is ChatRequestAssistantMessage assistantMessage) + { + obj = new + { + Role = assistantMessage.Role.ToString(), + Content = assistantMessage.Content, + TooCall = assistantMessage.ToolCalls.Select(tc => + { + return tc switch + { + ChatCompletionsFunctionToolCall functionToolCall => new + { + Type = "Function", + Name = functionToolCall.Name, + Arguments = functionToolCall.Arguments, + Id = functionToolCall.Id, + } as object, + _ => throw new System.NotImplementedException(), + }; + }), + FunctionCallName = assistantMessage.FunctionCall?.Name, + FunctionCallArguments = assistantMessage.FunctionCall?.Arguments, + }; + } + + if (m is ChatRequestSystemMessage systemMessage) + { + obj = new + { + Role = systemMessage.Role.ToString(), + Content = systemMessage.Content, + }; + } + + if (m is ChatRequestFunctionMessage functionMessage) + { + obj = new + { + Role = functionMessage.Role.ToString(), + Content = functionMessage.Content, + Name = functionMessage.Name, + }; + } + + if (m is ChatRequestToolMessage toolCallMessage) + { + obj = new + { + Role = toolCallMessage.Role.ToString(), + Content = toolCallMessage.Content, + ToolCallId = toolCallMessage.ToolCallId, + }; + } + + objs.Add(obj ?? 
throw new System.NotImplementedException()); + } + + return new + { + OriginalMessage = originalMessage.ToString(), + ConvertedMessages = objs, + }; + }); + + var json = JsonSerializer.Serialize(jsonObjects, this.jsonSerializerOptions); + Approvals.Verify(json); + } +} diff --git a/dotnet/test/AutoGen.Tests/SingleAgentTest.cs b/dotnet/test/AutoGen.Tests/SingleAgentTest.cs index dc5bce06445..361855150b3 100644 --- a/dotnet/test/AutoGen.Tests/SingleAgentTest.cs +++ b/dotnet/test/AutoGen.Tests/SingleAgentTest.cs @@ -236,22 +236,23 @@ private async Task EchoFunctionCallExecutionStreamingTestAsync(IStreamingReplyAg }; var replyStream = await agent.GenerateReplyStreamingAsync(messages: new Message[] { message, helloWorld }, option); var answer = "[ECHO] Hello world"; - Message? finalReply = default; + IMessage? finalReply = default; await foreach (var reply in replyStream) { - reply.Role.Should().Be(Role.Assistant); reply.From.Should().Be(agent.Name); - finalReply = reply; - - var formatted = reply.FormatMessage(); - _output.WriteLine(formatted); } - finalReply!.Content.Should().Be(answer); - finalReply!.Role.Should().Be(Role.Assistant); - finalReply!.From.Should().Be(agent.Name); - finalReply!.FunctionName.Should().Be(nameof(EchoAsync)); + if (finalReply is ToolCallResultMessage toolCallResultMessage) + { + toolCallResultMessage.Result.Should().Be(answer); + toolCallResultMessage.From.Should().Be(agent.Name); + toolCallResultMessage.ToolCallMessage.FunctionName.Should().Be(nameof(EchoAsync)); + } + else + { + throw new Exception("unexpected message type"); + } } private async Task UpperCaseTest(IAgent agent) @@ -276,15 +277,21 @@ private async Task UpperCaseStreamingTestAsync(IStreamingReplyAgent agent) }; var replyStream = await agent.GenerateReplyStreamingAsync(messages: new Message[] { message, helloWorld }, option); var answer = "ABCDEFG"; - Message? finalReply = default; + TextMessage? 
finalReply = default; await foreach (var reply in replyStream) { - reply.Role.Should().Be(Role.Assistant); - reply.From.Should().Be(agent.Name); + if (reply is TextMessage textMessage) + { + textMessage.From.Should().Be(agent.Name); - // the content should be part of the answer - reply.Content.Should().Be(answer.Substring(0, reply.Content!.Length)); - finalReply = reply; + // the content should be part of the answer + textMessage.Content.Should().Be(answer.Substring(0, textMessage.Content!.Length)); + finalReply = textMessage; + + continue; + } + + throw new Exception("unexpected message type"); } finalReply!.Content.Should().Be(answer); From e9a472f3df97f758920e8d37edc34bd1ef7afd0f Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Tue, 13 Feb 2024 22:33:29 -0800 Subject: [PATCH 03/27] add more tests --- dotnet/src/AutoGen/Core/MessageEnvelope.cs | 314 ------------------ dotnet/src/AutoGen/Core/Role.cs | 54 --- .../OpenAI/Extension/MessageExtension.cs | 92 ++--- dotnet/src/AutoGen/OpenAI/GPTAgent.cs | 7 +- ...MessageTests.BasicMessageTest.approved.txt | 40 ++- .../test/AutoGen.Tests/OpenAIMessageTests.cs | 221 +++++++++++- dotnet/test/AutoGen.Tests/SingleAgentTest.cs | 5 +- 7 files changed, 298 insertions(+), 435 deletions(-) delete mode 100644 dotnet/src/AutoGen/Core/MessageEnvelope.cs delete mode 100644 dotnet/src/AutoGen/Core/Role.cs diff --git a/dotnet/src/AutoGen/Core/MessageEnvelope.cs b/dotnet/src/AutoGen/Core/MessageEnvelope.cs deleted file mode 100644 index b04bae44f44..00000000000 --- a/dotnet/src/AutoGen/Core/MessageEnvelope.cs +++ /dev/null @@ -1,314 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Message.cs - -using System; -using System.Collections.Generic; -using Azure.AI.OpenAI; - -namespace AutoGen; - -public interface IMessage -{ - string? From { get; set; } -} - -public interface IMessage : IMessage -{ - T Content { get; } -} - -internal class MessageEnvelope : IMessage -{ - public MessageEnvelope(T content, string? 
from = null, IDictionary? metadata = null) - { - this.Content = content; - this.From = from; - this.Metadata = metadata ?? new Dictionary(); - } - - public T Content { get; } - - public string? From { get; set; } - - public IDictionary Metadata { get; set; } -} - -public class TextMessage : IMessage -{ - public TextMessage(Role role, string content, string? from = null) - { - this.Content = content; - this.Role = role; - this.From = from; - } - - public Role Role { get; set; } - - public string Content { get; } - - public string? From { get; set; } - - public override string ToString() - { - return $"TextMessage({this.Role}, {this.Content}, {this.From})"; - } -} - -public class ImageMessage : IMessage -{ - public ImageMessage(Role role, string url, string? from = null) - { - this.Role = role; - this.From = from; - this.Url = url; - } - - public Role Role { get; set; } - - public string Url { get; set; } - - public string? From { get; set; } - - public override string ToString() - { - return $"ImageMessage({this.Role}, {this.Url}, {this.From})"; - } -} - -public class ToolCallMessage : IMessage -{ - public ToolCallMessage(string functionName, string functionArgs, string? from = null) - { - this.From = from; - this.FunctionName = functionName; - this.FunctionArguments = functionArgs; - } - - public string FunctionName { get; set; } - - public string FunctionArguments { get; set; } - - public string? From { get; set; } - - public override string ToString() - { - return $"ToolCallMessage({this.FunctionName}, {this.FunctionArguments}, {this.From})"; - } -} - -public class ToolCallResultMessage : IMessage -{ - public ToolCallResultMessage(string result, ToolCallMessage toolCallMessage, string? 
from = null) - { - this.From = from; - this.Result = result; - this.ToolCallMessage = toolCallMessage; - } - - /// - /// The original tool call message - /// - public ToolCallMessage ToolCallMessage { get; set; } - - /// - /// The result from the tool call - /// - public string Result { get; set; } - - public string? From { get; set; } - - public override string ToString() - { - return $"ToolCallResultMessage({this.Result}, {this.ToolCallMessage}, {this.From})"; - } -} - -public class MultiModalMessage : IMessage -{ - public MultiModalMessage(IEnumerable content, string? from = null) - { - this.Content = content; - this.From = from; - this.Validate(); - } - - public Role Role { get; set; } - - public IEnumerable Content { get; set; } - - public string? From { get; set; } - - private void Validate() - { - foreach (var message in this.Content) - { - if (message.From != this.From) - { - var reason = $"The from property of the message {message} is different from the from property of the aggregate message {this}"; - throw new ArgumentException($"Invalid aggregate message {reason}"); - } - } - - // all message must be either text or image - foreach (var message in this.Content) - { - if (message is not TextMessage && message is not ImageMessage) - { - var reason = $"The message {message} is not a text or image message"; - throw new ArgumentException($"Invalid aggregate message {reason}"); - } - } - } - - public override string ToString() - { - var stringBuilder = new System.Text.StringBuilder(); - stringBuilder.Append($"MultiModalMessage({this.Role}, {this.From})"); - foreach (var message in this.Content) - { - stringBuilder.Append($"\n\t{message}"); - } - - return stringBuilder.ToString(); - } -} - -public class ParallelToolCallResultMessage : IMessage -{ - public ParallelToolCallResultMessage(IEnumerable toolCallResult, string? 
from = null) - { - this.ToolCallResult = toolCallResult; - this.From = from; - this.Validate(); - } - - public IEnumerable ToolCallResult { get; set; } - - public string? From { get; set; } - - public bool Equals(IMessage other) - { - throw new NotImplementedException(); - } - - private void Validate() - { - // the from property of all messages should be the same with the from property of the aggregate message - foreach (var message in this.ToolCallResult) - { - if (message.From != this.From) - { - var reason = $"The from property of the message {message} is different from the from property of the aggregate message {this}"; - throw new ArgumentException($"Invalid aggregate message {reason}"); - } - } - } - - public override string ToString() - { - var stringBuilder = new System.Text.StringBuilder(); - stringBuilder.Append($"ParallelToolCallResultMessage({this.From})"); - foreach (var message in this.ToolCallResult) - { - stringBuilder.Append($"\n\t{message}"); - } - - return stringBuilder.ToString(); - } -} - -public class AggregateMessage : IMessage -{ - public AggregateMessage(IEnumerable messages, string? from = null) - { - this.From = from; - this.Messages = messages; - this.Validate(); - } - - public IEnumerable Messages { get; set; } - - public string? 
From { get; set; } - - private void Validate() - { - // the from property of all messages should be the same with the from property of the aggregate message - foreach (var message in this.Messages) - { - if (message.From != this.From) - { - var reason = $"The from property of the message {message} is different from the from property of the aggregate message {this}"; - throw new ArgumentException($"Invalid aggregate message {reason}"); - } - } - - // no nested aggregate message - foreach (var message in this.Messages) - { - if (message is AggregateMessage) - { - var reason = $"The message {message} is an aggregate message"; - throw new ArgumentException("Invalid aggregate message " + reason); - } - } - } - - public override string ToString() - { - var stringBuilder = new System.Text.StringBuilder(); - stringBuilder.Append($"AggregateMessage({this.From})"); - foreach (var message in this.Messages) - { - stringBuilder.Append($"\n\t{message}"); - } - - return stringBuilder.ToString(); - } -} - -public class Message : IMessage -{ - public Message( - Role role, - string? content, - string? from = null, - FunctionCall? functionCall = null) - { - this.Role = role; - this.Content = content; - this.From = from; - this.FunctionName = functionCall?.Name; - this.FunctionArguments = functionCall?.Arguments; - } - - public Message(Message other) - : this(other.Role, other.Content, other.From) - { - this.FunctionName = other.FunctionName; - this.FunctionArguments = other.FunctionArguments; - this.Value = other.Value; - this.Metadata = other.Metadata; - } - - public Role Role { get; set; } - - public string? Content { get; set; } - - public string? From { get; set; } - - public string? FunctionName { get; set; } - - public string? FunctionArguments { get; set; } - - /// - /// raw message - /// - public object? 
Value { get; set; } - - public IList> Metadata { get; set; } = new List>(); - - public override string ToString() - { - return $"Message({this.Role}, {this.Content}, {this.From}, {this.FunctionName}, {this.FunctionArguments})"; - } -} diff --git a/dotnet/src/AutoGen/Core/Role.cs b/dotnet/src/AutoGen/Core/Role.cs deleted file mode 100644 index 4be88007ae9..00000000000 --- a/dotnet/src/AutoGen/Core/Role.cs +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Role.cs - -using System; - -namespace AutoGen; - -public readonly struct Role : IEquatable -{ - private readonly string label; - - internal Role(string name) - { - label = name; - } - - public static Role User { get; } = new Role("user"); - - public static Role Assistant { get; } = new Role("assistant"); - - public static Role System { get; } = new Role("system"); - - public static Role Function { get; } = new Role("function"); - - public bool Equals(Role other) - { - return label.Equals(other.label, StringComparison.OrdinalIgnoreCase); - } - - public override string ToString() - { - return label; - } - - public override bool Equals(object? obj) - { - return obj is Role other && Equals(other); - } - - public override int GetHashCode() - { - return label.GetHashCode(); - } - - public static bool operator ==(Role left, Role right) - { - return left.Equals(right); - } - - public static bool operator !=(Role left, Role right) - { - return !(left == right); - } -} diff --git a/dotnet/src/AutoGen/OpenAI/Extension/MessageExtension.cs b/dotnet/src/AutoGen/OpenAI/Extension/MessageExtension.cs index ba4fc2d6b4e..45171091168 100644 --- a/dotnet/src/AutoGen/OpenAI/Extension/MessageExtension.cs +++ b/dotnet/src/AutoGen/OpenAI/Extension/MessageExtension.cs @@ -164,17 +164,12 @@ public static ChatRequestFunctionMessage ToChatRequestFunctionMessage(this Messa return functionMessage; } - private static IMessage ToMessageEnvelope(this TMessage msg, string? 
from) - where TMessage : ChatRequestMessage - { - return new MessageEnvelope(msg, from); - } - - public static IEnumerable> ToOpenAIChatRequestMessage(this IAgent agent, IMessage message) + public static IEnumerable ToOpenAIChatRequestMessage(this IAgent agent, IMessage message) { if (message is IMessage oaiMessage) { - return [oaiMessage]; + // short-circuit + return [oaiMessage.Content]; } if (message.From != agent.Name) @@ -184,12 +179,13 @@ public static IEnumerable> ToOpenAIChatRequestMessa if (textMessage.Role == Role.System) { var msg = new ChatRequestSystemMessage(textMessage.Content); - return [msg.ToMessageEnvelope(message.From)]; + + return [msg]; } else { var msg = new ChatRequestUserMessage(textMessage.Content); - return [msg.ToMessageEnvelope(message.From)]; + return [msg]; } } else if (message is ImageMessage imageMessage) @@ -197,7 +193,7 @@ public static IEnumerable> ToOpenAIChatRequestMessa // multi-modal var msg = new ChatRequestUserMessage(new ChatMessageImageContentItem(new Uri(imageMessage.Url))); - return [msg.ToMessageEnvelope(message.From)]; + return [msg]; } else if (message is ToolCallMessage) { @@ -205,8 +201,12 @@ public static IEnumerable> ToOpenAIChatRequestMessa } else if (message is ToolCallResultMessage toolCallResult) { - var msg = new ChatRequestToolMessage(toolCallResult.Result, toolCallResult.ToolCallMessage.FunctionName); - return [msg.ToMessageEnvelope(message.From)]; + return toolCallResult.ToolCalls.Select(m => + { + var msg = new ChatRequestToolMessage(m.Result, m.FunctionName); + + return msg; + }); } else if (message is MultiModalMessage multiModalMessage) { @@ -221,44 +221,41 @@ public static IEnumerable> ToOpenAIChatRequestMessa }); var msg = new ChatRequestUserMessage(messageContent); - return [msg.ToMessageEnvelope(message.From)]; + return [msg]; } - else if (message is ParallelToolCallResultMessage parallelToolCallResultMessage) + else if (message is AggregateMessage aggregateMessage) { - return 
parallelToolCallResultMessage.ToolCallResult.Select(m => - { - var msg = new ChatRequestToolMessage(m.Result, m.ToolCallMessage.FunctionName); - - return msg.ToMessageEnvelope(message.From); - }); + // convert as user message + var resultMessage = aggregateMessage.Message2; + return resultMessage.ToolCalls.Select(m => new ChatRequestUserMessage(m.Result)); } else if (message is Message msg) { if (msg.Role == Role.System) { var systemMessage = new ChatRequestSystemMessage(msg.Content ?? string.Empty); - return [systemMessage.ToMessageEnvelope(message.From)]; + return [systemMessage]; } else if (msg.FunctionName is null && msg.FunctionArguments is null) { var userMessage = msg.ToChatRequestUserMessage(); - return [userMessage.ToMessageEnvelope(message.From)]; + return [userMessage]; } else if (msg.FunctionName is not null && msg.FunctionArguments is not null && msg.Content is not null) { if (msg.Role == Role.Function) { - return [new ChatRequestFunctionMessage(msg.FunctionName, msg.Content).ToMessageEnvelope(message.From)]; + return [new ChatRequestFunctionMessage(msg.FunctionName, msg.Content)]; } else { - return [new ChatRequestUserMessage(msg.Content).ToMessageEnvelope(message.From)]; + return [new ChatRequestUserMessage(msg.Content)]; } } else { var userMessage = new ChatRequestUserMessage(msg.Content ?? 
throw new ArgumentException("Content is null")); - return [userMessage.ToMessageEnvelope(message.From)]; + return [userMessage]; } } else @@ -276,33 +273,38 @@ public static IEnumerable> ToOpenAIChatRequestMessa } - return [new ChatRequestAssistantMessage(textMessage.Content).ToMessageEnvelope(message.From)]; + return [new ChatRequestAssistantMessage(textMessage.Content)]; } else if (message is ToolCallMessage toolCallMessage) { - // single tool call message var assistantMessage = new ChatRequestAssistantMessage(string.Empty); - assistantMessage.ToolCalls.Add(new ChatCompletionsFunctionToolCall(toolCallMessage.FunctionName, toolCallMessage.FunctionName, toolCallMessage.FunctionArguments)); + var toolCalls = toolCallMessage.ToolCalls.Select(tc => new ChatCompletionsFunctionToolCall(tc.FunctionName, tc.FunctionName, tc.FunctionArguments)); + foreach (var tc in toolCalls) + { + assistantMessage.ToolCalls.Add(tc); + } - return [assistantMessage.ToMessageEnvelope(message.From)]; + return [assistantMessage]; } - else if (message is AggregateMessage aggregateMessage) + else if (message is AggregateMessage aggregateMessage) { - // parallel tool call messages + var toolCallMessage1 = aggregateMessage.Message1; + var toolCallResultMessage = aggregateMessage.Message2; + var assistantMessage = new ChatRequestAssistantMessage(string.Empty); - foreach (var m in aggregateMessage.Messages) + var toolCalls = toolCallMessage1.ToolCalls.Select(tc => new ChatCompletionsFunctionToolCall(tc.FunctionName, tc.FunctionName, tc.FunctionArguments)); + foreach (var tc in toolCalls) { - if (m is ToolCallMessage toolCall) - { - assistantMessage.ToolCalls.Add(new ChatCompletionsFunctionToolCall(toolCall.FunctionName, toolCall.FunctionName, toolCall.FunctionArguments)); - } - else - { - throw new ArgumentException($"Unknown message type: {m.GetType()}"); - } + assistantMessage.ToolCalls.Add(tc); } - return [assistantMessage.ToMessageEnvelope(message.From)]; + var toolCallResults = 
toolCallResultMessage.ToolCalls.Select(tc => new ChatRequestToolMessage(tc.Result, tc.FunctionName)); + + // return assistantMessage and tool call result messages + var messages = new List { assistantMessage }; + messages.AddRange(toolCallResults); + + return messages; } else if (message is Message msg) { @@ -311,13 +313,13 @@ public static IEnumerable> ToOpenAIChatRequestMessa var assistantMessage = new ChatRequestAssistantMessage(msg.Content); assistantMessage.FunctionCall = new FunctionCall(msg.FunctionName, msg.FunctionArguments); var functionCallMessage = new ChatRequestFunctionMessage(msg.FunctionName, msg.Content); - return [assistantMessage.ToMessageEnvelope(message.From), functionCallMessage.ToMessageEnvelope(message.From)]; + return [assistantMessage, functionCallMessage]; } else { if (msg.Role == Role.Function) { - return [new ChatRequestFunctionMessage(msg.FunctionName!, msg.Content!).ToMessageEnvelope(message.From)]; + return [new ChatRequestFunctionMessage(msg.FunctionName!, msg.Content!)]; } else { @@ -327,7 +329,7 @@ public static IEnumerable> ToOpenAIChatRequestMessa assistantMessage.FunctionCall = new FunctionCall(msg.FunctionName, msg.FunctionArguments); } - return [assistantMessage.ToMessageEnvelope(message.From)]; + return [assistantMessage]; } } } diff --git a/dotnet/src/AutoGen/OpenAI/GPTAgent.cs b/dotnet/src/AutoGen/OpenAI/GPTAgent.cs index 34d482703da..0d4a37322cd 100644 --- a/dotnet/src/AutoGen/OpenAI/GPTAgent.cs +++ b/dotnet/src/AutoGen/OpenAI/GPTAgent.cs @@ -155,12 +155,12 @@ await foreach (var chunk in response) if (this.functionMap.TryGetValue(functionName, out var func)) { var result = await func(functionArguments); - yield return new ToolCallResultMessage(result, msg, from: this.Name); + yield return new ToolCallResultMessage(result, functionName, functionArguments, from: this.Name); } else { var errorMessage = $"Function {functionName} is not available. 
Available functions are: {string.Join(", ", this.functionMap.Select(f => f.Key))}"; - yield return new ToolCallResultMessage(errorMessage, msg, from: this.Name); + yield return new ToolCallResultMessage(errorMessage, functionName, functionArguments, from: this.Name); } } @@ -203,8 +203,7 @@ private ChatCompletionsOptions CreateChatCompletionsOptions(GenerateReplyOptions private IEnumerable ProcessMessages(IEnumerable messages) { // add system message if there's no system message in messages - var openAIMessages = messages.SelectMany(m => this.ToOpenAIChatRequestMessage(m)) - .Select(m => m.Content) ?? []; + var openAIMessages = messages.SelectMany(m => this.ToOpenAIChatRequestMessage(m)) ?? []; if (!openAIMessages.Any(m => m is ChatRequestSystemMessage)) { openAIMessages = new[] { new ChatRequestSystemMessage(_systemMessage) }.Concat(openAIMessages); diff --git a/dotnet/test/AutoGen.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt b/dotnet/test/AutoGen.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt index b01d87819bd..aeb5add2293 100644 --- a/dotnet/test/AutoGen.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt +++ b/dotnet/test/AutoGen.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt @@ -124,7 +124,7 @@ ] }, { - "OriginalMessage": "ToolCallMessage(test, test, assistant)", + "OriginalMessage": "ToolCallMessage(assistant)\n\tToolCall(test, test, )", "ConvertedMessages": [ { "Role": "assistant", @@ -143,7 +143,7 @@ ] }, { - "OriginalMessage": "ToolCallResultMessage(result, ToolCallMessage(test, test, assistant), user)", + "OriginalMessage": "ToolCallResultMessage(user)\n\tToolCall(test, test, result)", "ConvertedMessages": [ { "Role": "tool", @@ -153,22 +153,22 @@ ] }, { - "OriginalMessage": "ParallelToolCallResultMessage(user)\n\tToolCallResultMessage(result, ToolCallMessage(test, test, assistant), user)\n\tToolCallResultMessage(result, ToolCallMessage(test, test, assistant), user)", + 
"OriginalMessage": "ToolCallResultMessage(user)\n\tToolCall(result, test, test)\n\tToolCall(result, test, test)", "ConvertedMessages": [ { "Role": "tool", - "Content": "result", - "ToolCallId": "test" + "Content": "test", + "ToolCallId": "result" }, { "Role": "tool", - "Content": "result", - "ToolCallId": "test" + "Content": "test", + "ToolCallId": "result" } ] }, { - "OriginalMessage": "AggregateMessage(assistant)\n\tToolCallMessage(test, test, assistant)\n\tToolCallMessage(test, test, assistant)", + "OriginalMessage": "ToolCallMessage(assistant)\n\tToolCall(test, test, )\n\tToolCall(test, test, )", "ConvertedMessages": [ { "Role": "assistant", @@ -191,5 +191,29 @@ "FunctionCallArguments": null } ] + }, + { + "OriginalMessage": "AggregateMessage(assistant)\n\tToolCallMessage(assistant)\n\tToolCall(test, test, )\n\tToolCallResultMessage(assistant)\n\tToolCall(test, test, result)", + "ConvertedMessages": [ + { + "Role": "assistant", + "Content": "", + "TooCall": [ + { + "Type": "Function", + "Name": "test", + "Arguments": "test", + "Id": "test" + } + ], + "FunctionCallName": null, + "FunctionCallArguments": null + }, + { + "Role": "tool", + "Content": "result", + "ToolCallId": "test" + } + ] } ] \ No newline at end of file diff --git a/dotnet/test/AutoGen.Tests/OpenAIMessageTests.cs b/dotnet/test/AutoGen.Tests/OpenAIMessageTests.cs index f9235effdfb..baa6507e412 100644 --- a/dotnet/test/AutoGen.Tests/OpenAIMessageTests.cs +++ b/dotnet/test/AutoGen.Tests/OpenAIMessageTests.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. 
// OpenAIMessageTests.cs +using System; using System.Collections.Generic; using System.Linq; using System.Text.Json; @@ -9,6 +10,7 @@ using ApprovalTests.Reporters; using AutoGen.OpenAI; using Azure.AI.OpenAI; +using FluentAssertions; using Xunit; namespace AutoGen.Tests; @@ -46,25 +48,228 @@ public void BasicMessageTest() new ImageMessage(Role.User, "https://example.com/image.png", "user"), ], "user"), new ToolCallMessage("test", "test", "assistant"), - new ToolCallResultMessage("result", new ToolCallMessage("test", "test", "assistant"), "user"), - new ParallelToolCallResultMessage( + new ToolCallResultMessage("result", "test", "test", "user"), + new ToolCallResultMessage( [ - new ToolCallResultMessage("result", new ToolCallMessage("test", "test", "assistant"), "user"), - new ToolCallResultMessage("result", new ToolCallMessage("test", "test", "assistant"), "user"), + new ToolCall("result", "test", "test"), + new ToolCall("result", "test", "test"), ], "user"), - new AggregateMessage( + new ToolCallMessage( [ - new ToolCallMessage("test", "test", "assistant"), - new ToolCallMessage("test", "test", "assistant"), + new ToolCall("test", "test"), + new ToolCall("test", "test"), ], "assistant"), + new AggregateMessage( + message1: new ToolCallMessage("test", "test", "assistant"), + message2: new ToolCallResultMessage("result", "test", "test", "assistant"), "assistant"), ]; var agent = new EchoAgent("assistant"); - var oaiMessages = messages.Select(m => (m, agent.ToOpenAIChatRequestMessage(m).Select(m => m.Content))); + var oaiMessages = messages.Select(m => (m, agent.ToOpenAIChatRequestMessage(m))); VerifyOAIMessages(oaiMessages); } + [Fact] + public void ToOpenAIChatRequestMessageTest() + { + var agent = new EchoAgent("assistant"); + + // user message + IMessage message = new TextMessage(Role.User, "Hello", "user"); + var oaiMessages = agent.ToOpenAIChatRequestMessage(message); + + oaiMessages.Count().Should().Be(1); + oaiMessages.First().Should().BeOfType(); + var 
userMessage = (ChatRequestUserMessage)oaiMessages.First(); + userMessage.Content.Should().Be("Hello"); + + // user message test 2 + // even if Role is assistant, it should be converted to user message because it is from the user + message = new TextMessage(Role.Assistant, "Hello", "user"); + oaiMessages = agent.ToOpenAIChatRequestMessage(message); + + oaiMessages.Count().Should().Be(1); + oaiMessages.First().Should().BeOfType(); + userMessage = (ChatRequestUserMessage)oaiMessages.First(); + userMessage.Content.Should().Be("Hello"); + + // user message with multimodal content + // image + message = new ImageMessage(Role.User, "https://example.com/image.png", "user"); + oaiMessages = agent.ToOpenAIChatRequestMessage(message); + + oaiMessages.Count().Should().Be(1); + oaiMessages.First().Should().BeOfType(); + userMessage = (ChatRequestUserMessage)oaiMessages.First(); + userMessage.Content.Should().BeNullOrEmpty(); + userMessage.MultimodalContentItems.Count().Should().Be(1); + userMessage.MultimodalContentItems.First().Should().BeOfType(); + + // text and image + message = new MultiModalMessage( + [ + new TextMessage(Role.User, "Hello", "user"), + new ImageMessage(Role.User, "https://example.com/image.png", "user"), + ], "user"); + oaiMessages = agent.ToOpenAIChatRequestMessage(message); + + oaiMessages.Count().Should().Be(1); + oaiMessages.First().Should().BeOfType(); + userMessage = (ChatRequestUserMessage)oaiMessages.First(); + userMessage.Content.Should().BeNullOrEmpty(); + userMessage.MultimodalContentItems.Count().Should().Be(2); + userMessage.MultimodalContentItems.First().Should().BeOfType(); + + // assistant text message + message = new TextMessage(Role.Assistant, "How can I help you?", "assistant"); + oaiMessages = agent.ToOpenAIChatRequestMessage(message); + + oaiMessages.Count().Should().Be(1); + oaiMessages.First().Should().BeOfType(); + var assistantMessage = (ChatRequestAssistantMessage)oaiMessages.First(); + assistantMessage.Content.Should().Be("How 
can I help you?"); + + // assistant text message with single tool call + message = new ToolCallMessage("test", "test", "assistant"); + oaiMessages = agent.ToOpenAIChatRequestMessage(message); + + oaiMessages.Count().Should().Be(1); + oaiMessages.First().Should().BeOfType(); + assistantMessage = (ChatRequestAssistantMessage)oaiMessages.First(); + assistantMessage.Content.Should().BeNullOrEmpty(); + assistantMessage.ToolCalls.Count().Should().Be(1); + assistantMessage.ToolCalls.First().Should().BeOfType(); + + // the user is not supposed to send a tool call message + message = new ToolCallMessage("test", "test", "user"); + Func action = () => agent.ToOpenAIChatRequestMessage(message).First(); + action.Should().Throw().WithMessage("ToolCallMessage is not supported when message.From is not the same with agent"); + + // assistant text message with multiple tool calls + message = new ToolCallMessage( + toolCalls: + [ + new ToolCall("test", "test"), + new ToolCall("test", "test"), + ], "assistant"); + + oaiMessages = agent.ToOpenAIChatRequestMessage(message); + + oaiMessages.Count().Should().Be(1); + oaiMessages.First().Should().BeOfType(); + assistantMessage = (ChatRequestAssistantMessage)oaiMessages.First(); + assistantMessage.Content.Should().BeNullOrEmpty(); + assistantMessage.ToolCalls.Count().Should().Be(2); + + // tool call result message + message = new ToolCallResultMessage("result", "test", "test", "user"); + oaiMessages = agent.ToOpenAIChatRequestMessage(message); + + oaiMessages.Count().Should().Be(1); + oaiMessages.First().Should().BeOfType(); + var toolCallMessage = (ChatRequestToolMessage)oaiMessages.First(); + toolCallMessage.Content.Should().Be("result"); + + // tool call result message with multiple tool calls + message = new ToolCallResultMessage( + toolCalls: + [ + new ToolCall("result", "test", "test"), + new ToolCall("result", "test", "test"), + ], "user"); + + oaiMessages = agent.ToOpenAIChatRequestMessage(message); + + 
oaiMessages.Count().Should().Be(2); + oaiMessages.First().Should().BeOfType(); + toolCallMessage = (ChatRequestToolMessage)oaiMessages.First(); + toolCallMessage.Content.Should().Be("test"); + oaiMessages.Last().Should().BeOfType(); + toolCallMessage = (ChatRequestToolMessage)oaiMessages.Last(); + toolCallMessage.Content.Should().Be("test"); + + // aggregate message test + // aggregate message with tool call and tool call result will be returned by GPT agent if the tool call is automatically invoked inside agent + message = new AggregateMessage( + message1: new ToolCallMessage("test", "test", "assistant"), + message2: new ToolCallResultMessage("result", "test", "test", "assistant"), "assistant"); + + oaiMessages = agent.ToOpenAIChatRequestMessage(message); + + oaiMessages.Count().Should().Be(2); + oaiMessages.First().Should().BeOfType(); + assistantMessage = (ChatRequestAssistantMessage)oaiMessages.First(); + assistantMessage.Content.Should().BeNullOrEmpty(); + assistantMessage.ToolCalls.Count().Should().Be(1); + + oaiMessages.Last().Should().BeOfType(); + toolCallMessage = (ChatRequestToolMessage)oaiMessages.Last(); + toolCallMessage.Content.Should().Be("result"); + + // aggregate message test 2 + // if the aggregate message is from user, it should be converted to user message + message = new AggregateMessage( + message1: new ToolCallMessage("test", "test", "user"), + message2: new ToolCallResultMessage("result", "test", "test", "user"), "user"); + + oaiMessages = agent.ToOpenAIChatRequestMessage(message); + + oaiMessages.Count().Should().Be(1); + oaiMessages.First().Should().BeOfType(); + userMessage = (ChatRequestUserMessage)oaiMessages.First(); + userMessage.Content.Should().Be("result"); + + // aggregate message test 3 + // if the aggregate message is from user and contains multiple tool call results, it should be converted to user message + message = new AggregateMessage( + message1: new ToolCallMessage( + toolCalls: + [ + new ToolCall("test", "test"), + new 
ToolCall("test", "test"), + ], from: "user"), + message2: new ToolCallResultMessage( + toolCalls: + [ + new ToolCall("result", "test", "test"), + new ToolCall("result", "test", "test"), + ], from: "user"), "user"); + + oaiMessages = agent.ToOpenAIChatRequestMessage(message); + oaiMessages.Count().Should().Be(2); + oaiMessages.First().Should().BeOfType(); + oaiMessages.Last().Should().BeOfType(); + + // system message + message = new TextMessage(Role.System, "You are a helpful AI assistant"); + oaiMessages = agent.ToOpenAIChatRequestMessage(message); + oaiMessages.Count().Should().Be(1); + oaiMessages.First().Should().BeOfType(); + } + + [Fact] + public void ToOpenAIChatRequestMessageShortCircuitTest() + { + var agent = new EchoAgent("assistant"); + + ChatRequestMessage[] messages = + [ + new ChatRequestUserMessage("Hello"), + new ChatRequestAssistantMessage("How can I help you?"), + new ChatRequestSystemMessage("You are a helpful AI assistant"), + new ChatRequestFunctionMessage("result", "functionName"), + new ChatRequestToolMessage("test", "test"), + ]; + + foreach (var oaiMessage in messages) + { + IMessage message = new MessageEnvelope(oaiMessage); + var oaiMessages = agent.ToOpenAIChatRequestMessage(message); + oaiMessages.Count().Should().Be(1); + oaiMessages.First().Should().Be(oaiMessage); + } + } private void VerifyOAIMessages(IEnumerable<(IMessage, IEnumerable)> messages) { var jsonObjects = messages.Select(pair => diff --git a/dotnet/test/AutoGen.Tests/SingleAgentTest.cs b/dotnet/test/AutoGen.Tests/SingleAgentTest.cs index 361855150b3..a66c1cbed16 100644 --- a/dotnet/test/AutoGen.Tests/SingleAgentTest.cs +++ b/dotnet/test/AutoGen.Tests/SingleAgentTest.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Threading.Tasks; using AutoGen.OpenAI; using Azure.AI.OpenAI; @@ -245,9 +246,9 @@ await foreach (var reply in replyStream) if (finalReply is ToolCallResultMessage toolCallResultMessage) { - 
toolCallResultMessage.Result.Should().Be(answer); + toolCallResultMessage.ToolCalls.First().Result.Should().Be(answer); toolCallResultMessage.From.Should().Be(agent.Name); - toolCallResultMessage.ToolCallMessage.FunctionName.Should().Be(nameof(EchoAsync)); + toolCallResultMessage.ToolCalls.First().FunctionName.Should().Be(nameof(EchoAsync)); } else { From 8919058b2747f58dc4f78e8fff42fe235178f0ed Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Tue, 13 Feb 2024 22:40:08 -0800 Subject: [PATCH 04/27] add more test --- dotnet/src/AutoGen/Core/Agent/IAgent.cs | 46 +++++++++++++ .../AutoGen/Core/Message/AggregateMessage.cs | 53 +++++++++++++++ dotnet/src/AutoGen/Core/Message/IMessage.cs | 14 ++++ .../src/AutoGen/Core/Message/ImageMessage.cs | 25 +++++++ dotnet/src/AutoGen/Core/Message/Message.cs | 54 +++++++++++++++ .../AutoGen/Core/Message/MessageEnvelope.cs | 22 +++++++ .../AutoGen/Core/Message/MultiModalMessage.cs | 57 ++++++++++++++++ dotnet/src/AutoGen/Core/Message/Role.cs | 54 +++++++++++++++ .../src/AutoGen/Core/Message/TextMessage.cs | 25 +++++++ .../AutoGen/Core/Message/ToolCallMessage.cs | 65 +++++++++++++++++++ .../Core/Message/ToolCallResultMessage.cs | 55 ++++++++++++++++ 11 files changed, 470 insertions(+) create mode 100644 dotnet/src/AutoGen/Core/Agent/IAgent.cs create mode 100644 dotnet/src/AutoGen/Core/Message/AggregateMessage.cs create mode 100644 dotnet/src/AutoGen/Core/Message/IMessage.cs create mode 100644 dotnet/src/AutoGen/Core/Message/ImageMessage.cs create mode 100644 dotnet/src/AutoGen/Core/Message/Message.cs create mode 100644 dotnet/src/AutoGen/Core/Message/MessageEnvelope.cs create mode 100644 dotnet/src/AutoGen/Core/Message/MultiModalMessage.cs create mode 100644 dotnet/src/AutoGen/Core/Message/Role.cs create mode 100644 dotnet/src/AutoGen/Core/Message/TextMessage.cs create mode 100644 dotnet/src/AutoGen/Core/Message/ToolCallMessage.cs create mode 100644 dotnet/src/AutoGen/Core/Message/ToolCallResultMessage.cs diff --git 
a/dotnet/src/AutoGen/Core/Agent/IAgent.cs b/dotnet/src/AutoGen/Core/Agent/IAgent.cs new file mode 100644 index 00000000000..5bf348c8660 --- /dev/null +++ b/dotnet/src/AutoGen/Core/Agent/IAgent.cs @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// IAgent.cs + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Azure.AI.OpenAI; + +namespace AutoGen; + +public interface IAgent +{ + public string Name { get; } + + /// + /// Generate reply + /// + /// conversation history + /// completion option. If provided, it should override existing option if there's any + public Task GenerateReplyAsync( + IEnumerable messages, + GenerateReplyOptions? options = null, + CancellationToken cancellationToken = default); +} + +/// +/// agent that supports streaming reply +/// +public interface IStreamingReplyAgent : IAgent +{ + public Task> GenerateReplyStreamingAsync( + IEnumerable messages, + GenerateReplyOptions? options = null, + CancellationToken cancellationToken = default); +} + +public class GenerateReplyOptions +{ + public float? Temperature { get; set; } + + public int? MaxToken { get; set; } + + public string[]? StopSequence { get; set; } + + public FunctionDefinition[]? Functions { get; set; } +} diff --git a/dotnet/src/AutoGen/Core/Message/AggregateMessage.cs b/dotnet/src/AutoGen/Core/Message/AggregateMessage.cs new file mode 100644 index 00000000000..f15cda1b335 --- /dev/null +++ b/dotnet/src/AutoGen/Core/Message/AggregateMessage.cs @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Message.cs + +using System; +using System.Collections.Generic; + +namespace AutoGen; + +public class AggregateMessage : IMessage + where TMessage1 : IMessage + where TMessage2 : IMessage +{ + public AggregateMessage(TMessage1 message1, TMessage2 message2, string? 
from = null) + { + this.From = from; + this.Message1 = message1; + this.Message2 = message2; + this.Validate(); + } + + public TMessage1 Message1 { get; } + + public TMessage2 Message2 { get; } + + public string? From { get; set; } + + private void Validate() + { + var messages = new List { this.Message1, this.Message2 }; + // the from property of all messages should be the same with the from property of the aggregate message + + foreach (var message in messages) + { + if (message.From != this.From) + { + throw new ArgumentException($"The from property of the message {message} is different from the from property of the aggregate message {this}"); + } + } + } + + public override string ToString() + { + var stringBuilder = new System.Text.StringBuilder(); + var messages = new List { this.Message1, this.Message2 }; + stringBuilder.Append($"AggregateMessage({this.From})"); + foreach (var message in messages) + { + stringBuilder.Append($"\n\t{message}"); + } + + return stringBuilder.ToString(); + } +} diff --git a/dotnet/src/AutoGen/Core/Message/IMessage.cs b/dotnet/src/AutoGen/Core/Message/IMessage.cs new file mode 100644 index 00000000000..1c1d7499a1d --- /dev/null +++ b/dotnet/src/AutoGen/Core/Message/IMessage.cs @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Message.cs + +namespace AutoGen; + +public interface IMessage +{ + string? From { get; set; } +} + +public interface IMessage : IMessage +{ + T Content { get; } +} diff --git a/dotnet/src/AutoGen/Core/Message/ImageMessage.cs b/dotnet/src/AutoGen/Core/Message/ImageMessage.cs new file mode 100644 index 00000000000..4c94b8a63f0 --- /dev/null +++ b/dotnet/src/AutoGen/Core/Message/ImageMessage.cs @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ImageMessage.cs + +namespace AutoGen; + +public class ImageMessage : IMessage +{ + public ImageMessage(Role role, string url, string? 
from = null) + { + this.Role = role; + this.From = from; + this.Url = url; + } + + public Role Role { get; set; } + + public string Url { get; set; } + + public string? From { get; set; } + + public override string ToString() + { + return $"ImageMessage({this.Role}, {this.Url}, {this.From})"; + } +} diff --git a/dotnet/src/AutoGen/Core/Message/Message.cs b/dotnet/src/AutoGen/Core/Message/Message.cs new file mode 100644 index 00000000000..f84b296baf9 --- /dev/null +++ b/dotnet/src/AutoGen/Core/Message/Message.cs @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Message.cs + +using System.Collections.Generic; +using Azure.AI.OpenAI; + +namespace AutoGen; + +public class Message : IMessage +{ + public Message( + Role role, + string? content, + string? from = null, + FunctionCall? functionCall = null) + { + this.Role = role; + this.Content = content; + this.From = from; + this.FunctionName = functionCall?.Name; + this.FunctionArguments = functionCall?.Arguments; + } + + public Message(Message other) + : this(other.Role, other.Content, other.From) + { + this.FunctionName = other.FunctionName; + this.FunctionArguments = other.FunctionArguments; + this.Value = other.Value; + this.Metadata = other.Metadata; + } + + public Role Role { get; set; } + + public string? Content { get; set; } + + public string? From { get; set; } + + public string? FunctionName { get; set; } + + public string? FunctionArguments { get; set; } + + /// + /// raw message + /// + public object? 
Value { get; set; } + + public IList> Metadata { get; set; } = new List>(); + + public override string ToString() + { + return $"Message({this.Role}, {this.Content}, {this.From}, {this.FunctionName}, {this.FunctionArguments})"; + } +} diff --git a/dotnet/src/AutoGen/Core/Message/MessageEnvelope.cs b/dotnet/src/AutoGen/Core/Message/MessageEnvelope.cs new file mode 100644 index 00000000000..b3bd85dd50d --- /dev/null +++ b/dotnet/src/AutoGen/Core/Message/MessageEnvelope.cs @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// MessageEnvelope.cs + +using System.Collections.Generic; + +namespace AutoGen; + +public class MessageEnvelope : IMessage +{ + public MessageEnvelope(T content, string? from = null, IDictionary? metadata = null) + { + this.Content = content; + this.From = from; + this.Metadata = metadata ?? new Dictionary(); + } + + public T Content { get; } + + public string? From { get; set; } + + public IDictionary Metadata { get; set; } +} diff --git a/dotnet/src/AutoGen/Core/Message/MultiModalMessage.cs b/dotnet/src/AutoGen/Core/Message/MultiModalMessage.cs new file mode 100644 index 00000000000..9649387687a --- /dev/null +++ b/dotnet/src/AutoGen/Core/Message/MultiModalMessage.cs @@ -0,0 +1,57 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Message.cs + +using System; +using System.Collections.Generic; + +namespace AutoGen; + +public class MultiModalMessage : IMessage +{ + public MultiModalMessage(IEnumerable content, string? from = null) + { + this.Content = content; + this.From = from; + this.Validate(); + } + + public Role Role { get; set; } + + public IEnumerable Content { get; set; } + + public string? 
From { get; set; } + + private void Validate() + { + foreach (var message in this.Content) + { + if (message.From != this.From) + { + var reason = $"The from property of the message {message} is different from the from property of the aggregate message {this}"; + throw new ArgumentException($"Invalid aggregate message {reason}"); + } + } + + // all message must be either text or image + foreach (var message in this.Content) + { + if (message is not TextMessage && message is not ImageMessage) + { + var reason = $"The message {message} is not a text or image message"; + throw new ArgumentException($"Invalid aggregate message {reason}"); + } + } + } + + public override string ToString() + { + var stringBuilder = new System.Text.StringBuilder(); + stringBuilder.Append($"MultiModalMessage({this.Role}, {this.From})"); + foreach (var message in this.Content) + { + stringBuilder.Append($"\n\t{message}"); + } + + return stringBuilder.ToString(); + } +} diff --git a/dotnet/src/AutoGen/Core/Message/Role.cs b/dotnet/src/AutoGen/Core/Message/Role.cs new file mode 100644 index 00000000000..4be88007ae9 --- /dev/null +++ b/dotnet/src/AutoGen/Core/Message/Role.cs @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Role.cs + +using System; + +namespace AutoGen; + +public readonly struct Role : IEquatable +{ + private readonly string label; + + internal Role(string name) + { + label = name; + } + + public static Role User { get; } = new Role("user"); + + public static Role Assistant { get; } = new Role("assistant"); + + public static Role System { get; } = new Role("system"); + + public static Role Function { get; } = new Role("function"); + + public bool Equals(Role other) + { + return label.Equals(other.label, StringComparison.OrdinalIgnoreCase); + } + + public override string ToString() + { + return label; + } + + public override bool Equals(object? 
obj) + { + return obj is Role other && Equals(other); + } + + public override int GetHashCode() + { + return label.GetHashCode(); + } + + public static bool operator ==(Role left, Role right) + { + return left.Equals(right); + } + + public static bool operator !=(Role left, Role right) + { + return !(left == right); + } +} diff --git a/dotnet/src/AutoGen/Core/Message/TextMessage.cs b/dotnet/src/AutoGen/Core/Message/TextMessage.cs new file mode 100644 index 00000000000..288234830a9 --- /dev/null +++ b/dotnet/src/AutoGen/Core/Message/TextMessage.cs @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// TextMessage.cs + +namespace AutoGen; + +public class TextMessage : IMessage +{ + public TextMessage(Role role, string content, string? from = null) + { + this.Content = content; + this.Role = role; + this.From = from; + } + + public Role Role { get; set; } + + public string Content { get; } + + public string? From { get; set; } + + public override string ToString() + { + return $"TextMessage({this.Role}, {this.Content}, {this.From})"; + } +} diff --git a/dotnet/src/AutoGen/Core/Message/ToolCallMessage.cs b/dotnet/src/AutoGen/Core/Message/ToolCallMessage.cs new file mode 100644 index 00000000000..80efae58a19 --- /dev/null +++ b/dotnet/src/AutoGen/Core/Message/ToolCallMessage.cs @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Message.cs + +using System.Collections.Generic; +using System.Text; + +namespace AutoGen; + +public class ToolCall +{ + public ToolCall(string functionName, string functionArgs) + { + this.FunctionName = functionName; + this.FunctionArguments = functionArgs; + } + + public ToolCall(string functionName, string functionArgs, string result) + { + this.FunctionName = functionName; + this.FunctionArguments = functionArgs; + this.Result = result; + } + + public string FunctionName { get; set; } + + public string FunctionArguments { get; set; } + + public string? 
Result { get; set; } + + public override string ToString() + { + return $"ToolCall({this.FunctionName}, {this.FunctionArguments}, {this.Result})"; + } +} + +public class ToolCallMessage : IMessage +{ + public ToolCallMessage(IEnumerable toolCalls, string? from = null) + { + this.From = from; + this.ToolCalls = toolCalls; + } + + public ToolCallMessage(string functionName, string functionArgs, string? from = null) + { + this.From = from; + this.ToolCalls = new List { new ToolCall(functionName, functionArgs) }; + } + + public IEnumerable ToolCalls { get; set; } + + public string? From { get; set; } + + public override string ToString() + { + var sb = new StringBuilder(); + sb.Append($"ToolCallMessage({this.From})"); + foreach (var toolCall in this.ToolCalls) + { + sb.Append($"\n\t{toolCall}"); + } + + return sb.ToString(); + } +} diff --git a/dotnet/src/AutoGen/Core/Message/ToolCallResultMessage.cs b/dotnet/src/AutoGen/Core/Message/ToolCallResultMessage.cs new file mode 100644 index 00000000000..0b4fe3d4523 --- /dev/null +++ b/dotnet/src/AutoGen/Core/Message/ToolCallResultMessage.cs @@ -0,0 +1,55 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ToolCallResultMessage.cs + +using System.Collections.Generic; +using System.Text; + +namespace AutoGen; + +public class ToolCallResultMessage : IMessage +{ + public ToolCallResultMessage(IEnumerable toolCalls, string? from = null) + { + this.From = from; + this.ToolCalls = toolCalls; + } + + public ToolCallResultMessage(string result, string functionName, string functionArgs, string? from = null) + { + this.From = from; + var toolCall = new ToolCall(functionName, functionArgs); + toolCall.Result = result; + this.ToolCalls = [toolCall]; + } + + /// + /// The original tool call message + /// + public IEnumerable ToolCalls { get; set; } + + public string? 
From { get; set; } + + public override string ToString() + { + var sb = new StringBuilder(); + sb.Append($"ToolCallResultMessage({this.From})"); + foreach (var toolCall in this.ToolCalls) + { + sb.Append($"\n\t{toolCall}"); + } + + return sb.ToString(); + } + + private void Validate() + { + // each tool call must have a result + foreach (var toolCall in this.ToolCalls) + { + if (string.IsNullOrEmpty(toolCall.Result)) + { + throw new System.ArgumentException($"The tool call {toolCall} does not have a result"); + } + } + } +} From f99659e1dbf5fc5bbc694942b779e9076adf80e7 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Wed, 14 Feb 2024 00:01:49 -0800 Subject: [PATCH 05/27] fix build error --- dotnet/src/AutoGen/Core/Agent/IAgent.cs | 4 +-- dotnet/src/AutoGen/Core/IAgent.cs | 46 ------------------------- 2 files changed, 2 insertions(+), 48 deletions(-) delete mode 100644 dotnet/src/AutoGen/Core/IAgent.cs diff --git a/dotnet/src/AutoGen/Core/Agent/IAgent.cs b/dotnet/src/AutoGen/Core/Agent/IAgent.cs index 5bf348c8660..3e4ab709236 100644 --- a/dotnet/src/AutoGen/Core/Agent/IAgent.cs +++ b/dotnet/src/AutoGen/Core/Agent/IAgent.cs @@ -17,8 +17,8 @@ public interface IAgent /// /// conversation history /// completion option. If provided, it should override existing option if there's any - public Task GenerateReplyAsync( - IEnumerable messages, + public Task GenerateReplyAsync( + IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default); } diff --git a/dotnet/src/AutoGen/Core/IAgent.cs b/dotnet/src/AutoGen/Core/IAgent.cs deleted file mode 100644 index 3e4ab709236..00000000000 --- a/dotnet/src/AutoGen/Core/IAgent.cs +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. 
-// IAgent.cs - -using System.Collections.Generic; -using System.Threading; -using System.Threading.Tasks; -using Azure.AI.OpenAI; - -namespace AutoGen; - -public interface IAgent -{ - public string Name { get; } - - /// - /// Generate reply - /// - /// conversation history - /// completion option. If provided, it should override existing option if there's any - public Task GenerateReplyAsync( - IEnumerable messages, - GenerateReplyOptions? options = null, - CancellationToken cancellationToken = default); -} - -/// -/// agent that supports streaming reply -/// -public interface IStreamingReplyAgent : IAgent -{ - public Task> GenerateReplyStreamingAsync( - IEnumerable messages, - GenerateReplyOptions? options = null, - CancellationToken cancellationToken = default); -} - -public class GenerateReplyOptions -{ - public float? Temperature { get; set; } - - public int? MaxToken { get; set; } - - public string[]? StopSequence { get; set; } - - public FunctionDefinition[]? Functions { get; set; } -} From 1e3364b81243c35304c8321ffa711037c5894e66 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Wed, 14 Feb 2024 11:07:19 -0800 Subject: [PATCH 06/27] rename header --- dotnet/src/AutoGen/Core/Message/AggregateMessage.cs | 2 +- dotnet/src/AutoGen/Core/Message/IMessage.cs | 2 +- dotnet/src/AutoGen/Core/Message/MultiModalMessage.cs | 2 +- dotnet/src/AutoGen/Core/Message/ToolCallMessage.cs | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/dotnet/src/AutoGen/Core/Message/AggregateMessage.cs b/dotnet/src/AutoGen/Core/Message/AggregateMessage.cs index f15cda1b335..a375f7c2b38 100644 --- a/dotnet/src/AutoGen/Core/Message/AggregateMessage.cs +++ b/dotnet/src/AutoGen/Core/Message/AggregateMessage.cs @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. 
-// Message.cs +// AggregateMessage.cs using System; using System.Collections.Generic; diff --git a/dotnet/src/AutoGen/Core/Message/IMessage.cs b/dotnet/src/AutoGen/Core/Message/IMessage.cs index 1c1d7499a1d..9ade6c1ab66 100644 --- a/dotnet/src/AutoGen/Core/Message/IMessage.cs +++ b/dotnet/src/AutoGen/Core/Message/IMessage.cs @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Message.cs +// IMessage.cs namespace AutoGen; diff --git a/dotnet/src/AutoGen/Core/Message/MultiModalMessage.cs b/dotnet/src/AutoGen/Core/Message/MultiModalMessage.cs index 9649387687a..cfbfce677a5 100644 --- a/dotnet/src/AutoGen/Core/Message/MultiModalMessage.cs +++ b/dotnet/src/AutoGen/Core/Message/MultiModalMessage.cs @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Message.cs +// MultiModalMessage.cs using System; using System.Collections.Generic; diff --git a/dotnet/src/AutoGen/Core/Message/ToolCallMessage.cs b/dotnet/src/AutoGen/Core/Message/ToolCallMessage.cs index 80efae58a19..81656c014ee 100644 --- a/dotnet/src/AutoGen/Core/Message/ToolCallMessage.cs +++ b/dotnet/src/AutoGen/Core/Message/ToolCallMessage.cs @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. 
-// Message.cs +// ToolCallMessage.cs using System.Collections.Generic; using System.Text; From 01e492ecb7156d7b7e725977d01490e51901a5b3 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Wed, 14 Feb 2024 12:16:33 -0800 Subject: [PATCH 07/27] add semantic kernel project --- dotnet/AutoGen.sln | 9 +- dotnet/eng/Version.props | 6 +- .../AutoGen.BasicSample.csproj | 1 + .../Example09_SemanticKernel.cs | 29 ++ dotnet/sample/AutoGen.BasicSamples/Program.cs | 2 +- .../AutoGen.SemanticKernel.csproj | 27 ++ .../Extension/KernelExtension.cs | 14 + .../SemanticKernelAgent.cs | 271 ++++++++++++++++++ .../src/AutoGen/Core/Agent/UserProxyAgent.cs | 2 +- dotnet/src/AutoGen/OpenAI/GPTAgent.cs | 5 +- 10 files changed, 359 insertions(+), 7 deletions(-) create mode 100644 dotnet/sample/AutoGen.BasicSamples/Example09_SemanticKernel.cs create mode 100644 dotnet/src/AutoGen.SemanticKernel/AutoGen.SemanticKernel.csproj create mode 100644 dotnet/src/AutoGen.SemanticKernel/Extension/KernelExtension.cs create mode 100644 dotnet/src/AutoGen.SemanticKernel/SemanticKernelAgent.cs diff --git a/dotnet/AutoGen.sln b/dotnet/AutoGen.sln index b4697e5932a..bf3fee6cd93 100644 --- a/dotnet/AutoGen.sln +++ b/dotnet/AutoGen.sln @@ -21,7 +21,9 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "sample", "sample", "{FBFEAD EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.DotnetInteractive", "src\AutoGen.DotnetInteractive\AutoGen.DotnetInteractive.csproj", "{B61D8008-7FB7-4C0E-8044-3A74AA63A596}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AutoGen.LMStudio", "src\AutoGen.LMStudio\AutoGen.LMStudio.csproj", "{F98BDA9B-8657-4BA8-9B03-BAEA454CAE60}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.LMStudio", "src\AutoGen.LMStudio\AutoGen.LMStudio.csproj", "{F98BDA9B-8657-4BA8-9B03-BAEA454CAE60}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AutoGen.SemanticKernel", "src\AutoGen.SemanticKernel\AutoGen.SemanticKernel.csproj", 
"{45D6FC80-36F3-4967-9663-E20B63824621}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution @@ -57,6 +59,10 @@ Global {F98BDA9B-8657-4BA8-9B03-BAEA454CAE60}.Debug|Any CPU.Build.0 = Debug|Any CPU {F98BDA9B-8657-4BA8-9B03-BAEA454CAE60}.Release|Any CPU.ActiveCfg = Release|Any CPU {F98BDA9B-8657-4BA8-9B03-BAEA454CAE60}.Release|Any CPU.Build.0 = Release|Any CPU + {45D6FC80-36F3-4967-9663-E20B63824621}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {45D6FC80-36F3-4967-9663-E20B63824621}.Debug|Any CPU.Build.0 = Debug|Any CPU + {45D6FC80-36F3-4967-9663-E20B63824621}.Release|Any CPU.ActiveCfg = Release|Any CPU + {45D6FC80-36F3-4967-9663-E20B63824621}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -69,6 +75,7 @@ Global {7EBF916A-A7B1-4B74-AF10-D705B7A18F58} = {FBFEAD1F-29EB-4D99-A672-0CD8473E10B9} {B61D8008-7FB7-4C0E-8044-3A74AA63A596} = {18BF8DD7-0585-48BF-8F97-AD333080CE06} {F98BDA9B-8657-4BA8-9B03-BAEA454CAE60} = {18BF8DD7-0585-48BF-8F97-AD333080CE06} + {45D6FC80-36F3-4967-9663-E20B63824621} = {18BF8DD7-0585-48BF-8F97-AD333080CE06} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {93384647-528D-46C8-922C-8DB36A382F0B} diff --git a/dotnet/eng/Version.props b/dotnet/eng/Version.props index fc670ed2d27..a7d8a65ed44 100644 --- a/dotnet/eng/Version.props +++ b/dotnet/eng/Version.props @@ -1,11 +1,11 @@ - 1.0.0-beta.12 - 1.0.1 + 1.0.0-beta.13 + 1.4.0 5.0.0 4.3.0 - 5.2.4 + 6.0.0 6.8.0 2.4.2 17.7.0 diff --git a/dotnet/sample/AutoGen.BasicSamples/AutoGen.BasicSample.csproj b/dotnet/sample/AutoGen.BasicSamples/AutoGen.BasicSample.csproj index ae9f3fdc057..f28be25b034 100644 --- a/dotnet/sample/AutoGen.BasicSamples/AutoGen.BasicSample.csproj +++ b/dotnet/sample/AutoGen.BasicSamples/AutoGen.BasicSample.csproj @@ -12,6 +12,7 @@ + diff --git a/dotnet/sample/AutoGen.BasicSamples/Example09_SemanticKernel.cs 
b/dotnet/sample/AutoGen.BasicSamples/Example09_SemanticKernel.cs new file mode 100644 index 00000000000..3a1486dc25f --- /dev/null +++ b/dotnet/sample/AutoGen.BasicSamples/Example09_SemanticKernel.cs @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Example09_SemanticKernel.cs + +using Microsoft.SemanticKernel; +using AutoGen.SemanticKernel.Extension; +namespace AutoGen.BasicSample; + +public class Example09_SemanticKernel +{ + public static async Task RunAsync() + { + var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable."); + var modelId = "gpt-3.5-turbo"; + var kernel = Kernel.CreateBuilder() + .AddOpenAIChatCompletion(modelId: modelId, apiKey: openAIKey) + .Build(); + + var skAgent = kernel.ToSemanticKernelAgent(name: "skAgent", systemMessage: "You are a helpful AI assistant") + .RegisterPrintFormatMessageHook(); + + var userProxyAgent = new UserProxyAgent(name: "user", humanInputMode: ConversableAgent.HumanInputMode.ALWAYS); + + await userProxyAgent.InitiateChatAsync( + receiver: skAgent, + message: "Hey assistant, please help me to do some tasks.", + maxRound: 10); + } + +} diff --git a/dotnet/sample/AutoGen.BasicSamples/Program.cs b/dotnet/sample/AutoGen.BasicSamples/Program.cs index 4b7d16f69ef..62ce3310ce1 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Program.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Program.cs @@ -3,4 +3,4 @@ using AutoGen.BasicSample; -await Example08_LMStudio.RunAsync(); +await Example09_SemanticKernel.RunAsync(); diff --git a/dotnet/src/AutoGen.SemanticKernel/AutoGen.SemanticKernel.csproj b/dotnet/src/AutoGen.SemanticKernel/AutoGen.SemanticKernel.csproj new file mode 100644 index 00000000000..36e9325663a --- /dev/null +++ b/dotnet/src/AutoGen.SemanticKernel/AutoGen.SemanticKernel.csproj @@ -0,0 +1,27 @@ + + + + netstandard2.0 + AutoGen.SemanticKernel + + + + + + + AutoGen.SemanticKernel + + This package contains 
the semantic kernel integration for AutoGen + + + + + + + + + + + + + diff --git a/dotnet/src/AutoGen.SemanticKernel/Extension/KernelExtension.cs b/dotnet/src/AutoGen.SemanticKernel/Extension/KernelExtension.cs new file mode 100644 index 00000000000..063b6f1f989 --- /dev/null +++ b/dotnet/src/AutoGen.SemanticKernel/Extension/KernelExtension.cs @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// KernelExtension.cs + +using Microsoft.SemanticKernel; + +namespace AutoGen.SemanticKernel.Extension; + +public static class KernelExtension +{ + public static SemanticKernelAgent ToSemanticKernelAgent(this Kernel kernel, string name, string systemMessage = "You are a helpful AI assistant") + { + return new SemanticKernelAgent(kernel, name, systemMessage); + } +} diff --git a/dotnet/src/AutoGen.SemanticKernel/SemanticKernelAgent.cs b/dotnet/src/AutoGen.SemanticKernel/SemanticKernelAgent.cs new file mode 100644 index 00000000000..2500bcce5f7 --- /dev/null +++ b/dotnet/src/AutoGen.SemanticKernel/SemanticKernelAgent.cs @@ -0,0 +1,271 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SemanticKernelAgent.cs + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Azure.AI.OpenAI; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Connectors.OpenAI; + +namespace AutoGen.SemanticKernel; + +/// +/// The agent that intergrade with the semantic kernel. 
+/// +public class SemanticKernelAgent : IStreamingReplyAgent +{ + private readonly Kernel _kernel; + private readonly string _systemMessage; + public SemanticKernelAgent( + Kernel kernel, + string name, + string systemMessage = "You are a helpful AI assistant") + { + _kernel = kernel; + this.Name = name; + _systemMessage = systemMessage; + } + public string Name { get; } + + + public async Task GenerateReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) + { + var chatMessageContents = ProcessMessage(messages); + // if there's no system message in chatMessageContents, add one to the beginning + if (!chatMessageContents.Any(c => c.Role == AuthorRole.System)) + { + chatMessageContents = new[] { new ChatMessageContent(AuthorRole.System, _systemMessage) }.Concat(chatMessageContents); + } + + var chatHistory = new ChatHistory(chatMessageContents); + var option = new OpenAIPromptExecutionSettings + { + Temperature = options?.Temperature ?? 0.7f, + MaxTokens = options?.MaxToken ?? 
1024, + StopSequences = options?.StopSequence, + }; + + var chatService = _kernel.GetRequiredService(); + + var reply = await chatService.GetChatMessageContentsAsync(chatHistory, option, _kernel, cancellationToken); + + if (reply.Count() == 1) + { + // might be a plain text return or a function call return + var msg = reply.First(); + if (msg is OpenAIChatMessageContent oaiContent) + { + if (oaiContent.Content is string content) + { + return new Message(Role.Assistant, content, this.Name); + } + else if (oaiContent.ToolCalls is { Count: 1 } && oaiContent.ToolCalls.First() is ChatCompletionsFunctionToolCall toolCall) + { + return new Message(Role.Assistant, content: null, this.Name) + { + FunctionName = toolCall.Name, + FunctionArguments = toolCall.Arguments, + }; + } + else + { + // parallel function call is not supported + throw new InvalidOperationException("Unsupported return type, only plain text and function call are supported."); + } + } + else + { + throw new InvalidOperationException("Unsupported return type"); + } + } + else + { + throw new InvalidOperationException("Unsupported return type, multiple messages are not supported."); + } + } + + public async Task> GenerateReplyStreamingAsync( + IEnumerable messages, + GenerateReplyOptions? options = null, + CancellationToken cancellationToken = default) + { + var chatMessageContents = ProcessMessage(messages); + // if there's no system message in chatMessageContents, add one to the beginning + if (!chatMessageContents.Any(c => c.Role == AuthorRole.System)) + { + chatMessageContents = new[] { new ChatMessageContent(AuthorRole.System, _systemMessage) }.Concat(chatMessageContents); + } + + var chatHistory = new ChatHistory(chatMessageContents); + var option = new OpenAIPromptExecutionSettings + { + Temperature = options?.Temperature ?? 0.7f, + MaxTokens = options?.MaxToken ?? 
1024, + StopSequences = options?.StopSequence, + }; + + var chatService = _kernel.GetRequiredService(); + var response = chatService.GetStreamingChatMessageContentsAsync(chatHistory, option, _kernel, cancellationToken); + + return ProcessMessage(response); + } + + private async IAsyncEnumerable ProcessMessage(IAsyncEnumerable response) + { + string? text = null; + await foreach (var content in response) + { + if (content is OpenAIStreamingChatMessageContent oaiStreamingChatContent && oaiStreamingChatContent.Content is string chunk) + { + text += chunk; + yield return new Message(Role.Assistant, text, this.Name); + } + else + { + throw new InvalidOperationException("Unsupported return type"); + } + } + + if (text is not null) + { + yield return new Message(Role.Assistant, text, this.Name); + } + } + + private IEnumerable ProcessMessage(IEnumerable messages) + { + return messages.SelectMany(m => + { + if (m is IMessage chatMessageContent) + { + return [chatMessageContent.Content]; + } + if (m.From == this.Name) + { + return ProcessMessageForSelf(m); + } + else + { + return ProcessMessageForOthers(m); + } + }); + } + + private IEnumerable ProcessMessageForSelf(IMessage message) + { + return message switch + { + TextMessage textMessage => ProcessMessageForSelf(textMessage), + MultiModalMessage multiModalMessage => ProcessMessageForSelf(multiModalMessage), + Message m => ProcessMessageForSelf(m), + _ => throw new System.NotImplementedException(), + }; + } + + private IEnumerable ProcessMessageForOthers(IMessage message) + { + return message switch + { + TextMessage textMessage => ProcessMessageForOthers(textMessage), + MultiModalMessage multiModalMessage => ProcessMessageForOthers(multiModalMessage), + Message m => ProcessMessageForOthers(m), + _ => throw new System.NotImplementedException(), + }; + } + + private IEnumerable ProcessMessageForSelf(TextMessage message) + { + if (message.Role == Role.System) + { + return [new ChatMessageContent(AuthorRole.System, 
message.Content)]; + } + else + { + return [new ChatMessageContent(AuthorRole.Assistant, message.Content)]; + } + } + + + private IEnumerable ProcessMessageForOthers(TextMessage message) + { + if (message.Role == Role.System) + { + return [new ChatMessageContent(AuthorRole.System, message.Content)]; + } + else + { + return [new ChatMessageContent(AuthorRole.User, message.Content)]; + } + } + + private IEnumerable ProcessMessageForSelf(MultiModalMessage message) + { + throw new System.InvalidOperationException("MultiModalMessage is not supported in the semantic kernel if it's from self."); + } + + private IEnumerable ProcessMessageForOthers(MultiModalMessage message) + { + var collections = new ChatMessageContentItemCollection(); + foreach (var item in message.Content) + { + if (item is TextMessage textContent) + { + collections.Add(new TextContent(textContent.Content)); + } + else if (item is ImageMessage imageContent) + { + collections.Add(new ImageContent(new Uri(imageContent.Url))); + } + else + { + throw new InvalidOperationException($"Unsupported message type: {item.GetType().Name}"); + } + } + return [new ChatMessageContent(AuthorRole.User, collections)]; + } + + + private IEnumerable ProcessMessageForSelf(Message message) + { + if (message.Role == Role.System) + { + return [new ChatMessageContent(AuthorRole.System, message.Content)]; + } + else if (message.Content is string && message.FunctionName is null && message.FunctionArguments is null) + { + return [new ChatMessageContent(AuthorRole.Assistant, message.Content)]; + } + else if (message.Content is null && message.FunctionName is not null && message.FunctionArguments is not null) + { + throw new System.InvalidOperationException("Function call is not supported in the semantic kernel if it's from self."); + } + else + { + throw new System.InvalidOperationException("Unsupported message type"); + } + } + + private IEnumerable ProcessMessageForOthers(Message message) + { + if (message.Role == Role.System) + { 
+ return [new ChatMessageContent(AuthorRole.System, message.Content)]; + } + else if (message.Content is string && message.FunctionName is null && message.FunctionArguments is null) + { + return [new ChatMessageContent(AuthorRole.User, message.Content)]; + } + else if (message.Content is null && message.FunctionName is not null && message.FunctionArguments is not null) + { + throw new System.InvalidOperationException("Function call is not supported in the semantic kernel if it's from others."); + } + else + { + throw new System.InvalidOperationException("Unsupported message type"); + } + } +} diff --git a/dotnet/src/AutoGen/Core/Agent/UserProxyAgent.cs b/dotnet/src/AutoGen/Core/Agent/UserProxyAgent.cs index 7d298c862a1..5bcee19c1fd 100644 --- a/dotnet/src/AutoGen/Core/Agent/UserProxyAgent.cs +++ b/dotnet/src/AutoGen/Core/Agent/UserProxyAgent.cs @@ -15,7 +15,7 @@ public class UserProxyAgent : ConversableAgent string systemMessage = "You are a helpful AI assistant", ConversableAgentConfig? llmConfig = null, Func, CancellationToken, Task>? isTermination = null, - HumanInputMode humanInputMode = HumanInputMode.NEVER, + HumanInputMode humanInputMode = HumanInputMode.ALWAYS, IDictionary>>? functionMap = null, string? defaultReply = null) : base(name: name, diff --git a/dotnet/src/AutoGen/OpenAI/GPTAgent.cs b/dotnet/src/AutoGen/OpenAI/GPTAgent.cs index 0d4a37322cd..10d68c8aa4a 100644 --- a/dotnet/src/AutoGen/OpenAI/GPTAgent.cs +++ b/dotnet/src/AutoGen/OpenAI/GPTAgent.cs @@ -181,7 +181,10 @@ private ChatCompletionsOptions CreateChatCompletionsOptions(GenerateReplyOptions var functions = options?.Functions ?? 
_functions; if (functions is not null && functions.Count() > 0) { - settings.Functions = functions.ToList(); + foreach (var f in functions) + { + settings.Functions.Add(f); + } //foreach (var f in functions) //{ // settings.Tools.Add(new ChatCompletionsFunctionToolDefinition(f)); From bccfbc5804f4e202ec638ee8ef0382ab80ee7b3a Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Wed, 14 Feb 2024 15:01:19 -0800 Subject: [PATCH 08/27] update sk example --- dotnet/eng/Version.props | 1 + .../AutoGen.BasicSample.csproj | 3 +- .../Example09_SemanticKernel.cs | 46 ++++++++++++++++--- .../Extension/KernelExtension.cs | 4 +- .../SemanticKernelAgent.cs | 10 ++-- 5 files changed, 51 insertions(+), 13 deletions(-) diff --git a/dotnet/eng/Version.props b/dotnet/eng/Version.props index a7d8a65ed44..3930c55d068 100644 --- a/dotnet/eng/Version.props +++ b/dotnet/eng/Version.props @@ -3,6 +3,7 @@ 1.0.0-beta.13 1.4.0 + 1.4.0-alpha 5.0.0 4.3.0 6.0.0 diff --git a/dotnet/sample/AutoGen.BasicSamples/AutoGen.BasicSample.csproj b/dotnet/sample/AutoGen.BasicSamples/AutoGen.BasicSample.csproj index f28be25b034..75f5c546b29 100644 --- a/dotnet/sample/AutoGen.BasicSamples/AutoGen.BasicSample.csproj +++ b/dotnet/sample/AutoGen.BasicSamples/AutoGen.BasicSample.csproj @@ -6,7 +6,7 @@ enable enable True - $(NoWarn);CS8981;CS8600;CS8602;CS8604;CS8618;CS0219 + $(NoWarn);CS8981;CS8600;CS8602;CS8604;CS8618;CS0219;SKEXP0054 @@ -16,5 +16,6 @@ + diff --git a/dotnet/sample/AutoGen.BasicSamples/Example09_SemanticKernel.cs b/dotnet/sample/AutoGen.BasicSamples/Example09_SemanticKernel.cs index 3a1486dc25f..0344dc7cd02 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Example09_SemanticKernel.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Example09_SemanticKernel.cs @@ -1,27 +1,59 @@ // Copyright (c) Microsoft Corporation. All rights reserved. 
// Example09_SemanticKernel.cs -using Microsoft.SemanticKernel; +using System.ComponentModel; using AutoGen.SemanticKernel.Extension; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Connectors.OpenAI; namespace AutoGen.BasicSample; +public class LightPlugin +{ + public bool IsOn { get; set; } = false; + + [KernelFunction] + [Description("Gets the state of the light.")] + public string GetState() => this.IsOn ? "on" : "off"; + + [KernelFunction] + [Description("Changes the state of the light.'")] + public string ChangeState(bool newState) + { + this.IsOn = newState; + var state = this.GetState(); + + // Print the state to the console + Console.ForegroundColor = ConsoleColor.DarkBlue; + Console.WriteLine($"[Light is now {state}]"); + Console.ResetColor(); + + return state; + } +} + public class Example09_SemanticKernel { public static async Task RunAsync() { var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable."); var modelId = "gpt-3.5-turbo"; - var kernel = Kernel.CreateBuilder() - .AddOpenAIChatCompletion(modelId: modelId, apiKey: openAIKey) - .Build(); + var builder = Kernel.CreateBuilder() + .AddOpenAIChatCompletion(modelId: modelId, apiKey: openAIKey); + var kernel = builder.Build(); + var settings = new OpenAIPromptExecutionSettings + { + ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions, + }; - var skAgent = kernel.ToSemanticKernelAgent(name: "skAgent", systemMessage: "You are a helpful AI assistant") + kernel.Plugins.AddFromObject(new LightPlugin()); + var assistantAgent = kernel + .ToSemanticKernelAgent(name: "assistant", systemMessage: "You control the light", settings) .RegisterPrintFormatMessageHook(); - var userProxyAgent = new UserProxyAgent(name: "user", humanInputMode: ConversableAgent.HumanInputMode.ALWAYS); + var userProxyAgent = new UserProxyAgent(name: "user", humanInputMode: ConversableAgent.HumanInputMode.ALWAYS); await 
userProxyAgent.InitiateChatAsync( - receiver: skAgent, + receiver: assistantAgent, message: "Hey assistant, please help me to do some tasks.", maxRound: 10); } diff --git a/dotnet/src/AutoGen.SemanticKernel/Extension/KernelExtension.cs b/dotnet/src/AutoGen.SemanticKernel/Extension/KernelExtension.cs index 063b6f1f989..f1589ab09e6 100644 --- a/dotnet/src/AutoGen.SemanticKernel/Extension/KernelExtension.cs +++ b/dotnet/src/AutoGen.SemanticKernel/Extension/KernelExtension.cs @@ -7,8 +7,8 @@ namespace AutoGen.SemanticKernel.Extension; public static class KernelExtension { - public static SemanticKernelAgent ToSemanticKernelAgent(this Kernel kernel, string name, string systemMessage = "You are a helpful AI assistant") + public static SemanticKernelAgent ToSemanticKernelAgent(this Kernel kernel, string name, string systemMessage = "You are a helpful AI assistant", PromptExecutionSettings? settings = null) { - return new SemanticKernelAgent(kernel, name, systemMessage); + return new SemanticKernelAgent(kernel, name, systemMessage, settings); } } diff --git a/dotnet/src/AutoGen.SemanticKernel/SemanticKernelAgent.cs b/dotnet/src/AutoGen.SemanticKernel/SemanticKernelAgent.cs index 2500bcce5f7..370cb7740c7 100644 --- a/dotnet/src/AutoGen.SemanticKernel/SemanticKernelAgent.cs +++ b/dotnet/src/AutoGen.SemanticKernel/SemanticKernelAgent.cs @@ -20,14 +20,18 @@ public class SemanticKernelAgent : IStreamingReplyAgent { private readonly Kernel _kernel; private readonly string _systemMessage; + private readonly PromptExecutionSettings? _settings; + public SemanticKernelAgent( Kernel kernel, string name, - string systemMessage = "You are a helpful AI assistant") + string systemMessage = "You are a helpful AI assistant", + PromptExecutionSettings? 
settings = null) { _kernel = kernel; this.Name = name; _systemMessage = systemMessage; + _settings = settings; } public string Name { get; } @@ -42,7 +46,7 @@ public async Task GenerateReplyAsync(IEnumerable messages, Gen } var chatHistory = new ChatHistory(chatMessageContents); - var option = new OpenAIPromptExecutionSettings + var option = _settings ?? new OpenAIPromptExecutionSettings { Temperature = options?.Temperature ?? 0.7f, MaxTokens = options?.MaxToken ?? 1024, @@ -101,7 +105,7 @@ public async Task GenerateReplyAsync(IEnumerable messages, Gen } var chatHistory = new ChatHistory(chatMessageContents); - var option = new OpenAIPromptExecutionSettings + var option = _settings ?? new OpenAIPromptExecutionSettings { Temperature = options?.Temperature ?? 0.7f, MaxTokens = options?.MaxToken ?? 1024, From 1935f70ac17cfe2f2ab0eda7bb607a0b3b48fac1 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Wed, 14 Feb 2024 15:12:47 -0800 Subject: [PATCH 09/27] update dotnet version --- dotnet/global.json | 2 +- .../FilescopeNamespaceFunctionExample.cs | 2 +- dotnet/test/AutoGen.Tests/OpenAIMessageTests.cs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dotnet/global.json b/dotnet/global.json index 62b2d730ed5..a93054a455c 100644 --- a/dotnet/global.json +++ b/dotnet/global.json @@ -1,6 +1,6 @@ { "sdk": { "version": "8.0.101", - "rollForward": "latestMajor" + "rollForward": "latestMinor" } } \ No newline at end of file diff --git a/dotnet/test/AutoGen.SourceGenerator.Tests/FilescopeNamespaceFunctionExample.cs b/dotnet/test/AutoGen.SourceGenerator.Tests/FilescopeNamespaceFunctionExample.cs index f09dd008c1a..1c5c9dd79d8 100644 --- a/dotnet/test/AutoGen.SourceGenerator.Tests/FilescopeNamespaceFunctionExample.cs +++ b/dotnet/test/AutoGen.SourceGenerator.Tests/FilescopeNamespaceFunctionExample.cs @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. 
-// TopLevelStatementFunctionExample.cs +// FilescopeNamespaceFunctionExample.cs namespace AutoGen.SourceGenerator.Tests; public partial class FilescopeNamespaceFunctionExample diff --git a/dotnet/test/AutoGen.Tests/OpenAIMessageTests.cs b/dotnet/test/AutoGen.Tests/OpenAIMessageTests.cs index baa6507e412..564d55ce626 100644 --- a/dotnet/test/AutoGen.Tests/OpenAIMessageTests.cs +++ b/dotnet/test/AutoGen.Tests/OpenAIMessageTests.cs @@ -231,7 +231,7 @@ public void ToOpenAIChatRequestMessageTest() ], from: "user"), message2: new ToolCallResultMessage( toolCalls: - [ + [ new ToolCall("result", "test", "test"), new ToolCall("result", "test", "test"), ], from: "user"), "user"); From f0794050fee250c6392fb4ace0b1c7a214baae76 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Thu, 29 Feb 2024 15:19:25 -0800 Subject: [PATCH 10/27] use IMessage --- .../CodeSnippet/CreateAnAgent.cs | 11 +- .../CodeSnippet/FunctionCallCodeSnippet.cs | 13 +- .../CodeSnippet/MiddlewareAgentCodeSnippet.cs | 22 +- .../Example01_AssistantAgent.cs | 12 +- .../Example02_TwoAgent_MathChat.cs | 2 +- .../Example03_Agent_FunctionCall.cs | 13 +- .../Example05_Dalle_And_GPT4V.cs | 24 +- ...7_Dynamic_GroupChat_Calculate_Fibonacci.cs | 22 +- .../Example09_LMStudio_FunctionCall.cs | 2 +- .../Extension/AgentExtension.cs | 4 +- dotnet/src/AutoGen.LMStudio/LMStudioAgent.cs | 4 +- .../SemanticKernelAgent.cs | 2 +- .../src/AutoGen/Core/Agent/AssistantAgent.cs | 2 +- .../AutoGen/Core/Agent/ConversableAgent.cs | 19 +- .../AutoGen/Core/Agent/DefaultReplyAgent.cs | 6 +- .../AutoGen/Core/Agent/GroupChatManager.cs | 6 +- dotnet/src/AutoGen/Core/Agent/IAgent.cs | 4 +- .../src/AutoGen/Core/Agent/MiddlewareAgent.cs | 12 +- .../Core/Agent/MiddlewareStreamingAgent.cs | 4 +- .../src/AutoGen/Core/Agent/UserProxyAgent.cs | 2 +- .../AutoGen/Core/Extension/AgentExtension.cs | 42 +-- .../Core/Extension/GroupChatExtension.cs | 32 +- .../Core/Extension/MessageExtension.cs | 68 ++++ .../Core/Extension/MiddlewareExtension.cs | 14 +- 
.../src/AutoGen/Core/GroupChat/GroupChat.cs | 24 +- .../Core/GroupChat/SequentialGroupChat.cs | 14 +- dotnet/src/AutoGen/Core/IGroupChat.cs | 4 +- .../src/AutoGen/Core/Message/TextMessage.cs | 2 +- .../AutoGen/Core/Message/ToolCallMessage.cs | 5 +- .../Core/Message/ToolCallResultMessage.cs | 5 +- .../Core/Middleware/DelegateMiddleware.cs | 6 +- .../Core/Middleware/FunctionCallMiddleware.cs | 6 +- .../Core/Middleware/HumanInputMiddleware.cs | 8 +- .../AutoGen/Core/Middleware/IMiddleware.cs | 2 +- .../Core/Middleware/MiddlewareContext.cs | 4 +- .../Core/Middleware/PrintMessageMiddleware.cs | 2 +- dotnet/src/AutoGen/Core/Workflow/Workflow.cs | 12 +- dotnet/src/AutoGen/OpenAI/GPTAgent.cs | 4 +- .../Middleware/OpenAIMessageConnector.cs | 329 ++++++++++++++++++ dotnet/test/AutoGen.Tests/EchoAgent.cs | 6 +- dotnet/test/AutoGen.Tests/MathClassTest.cs | 8 +- .../test/AutoGen.Tests/MiddlewareAgentTest.cs | 40 +-- dotnet/test/AutoGen.Tests/MiddlewareTest.cs | 20 +- .../AutoGen.Tests/RegisterReplyAgentTest.cs | 7 +- dotnet/test/AutoGen.Tests/SingleAgentTest.cs | 23 +- dotnet/test/AutoGen.Tests/TwoAgentTest.cs | 4 +- dotnet/test/AutoGen.Tests/WorkflowTest.cs | 2 +- 47 files changed, 650 insertions(+), 229 deletions(-) create mode 100644 dotnet/src/AutoGen/OpenAI/Middleware/OpenAIMessageConnector.cs diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/CreateAnAgent.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/CreateAnAgent.cs index d2186eb7050..ae78cf91523 100644 --- a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/CreateAnAgent.cs +++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/CreateAnAgent.cs @@ -93,8 +93,10 @@ public async Task CodeSnippet4() }); var response = await assistantAgent.SendAsync("hello"); - response.FunctionName.Should().Be("UpperCase"); - response.Content.Should().BeNullOrEmpty(); + response.Should().BeOfType(); + var toolCallMessage = (ToolCallMessage)response; + toolCallMessage.ToolCalls.Count().Should().Be(1); + 
toolCallMessage.ToolCalls.First().FunctionName.Should().Be("UpperCase"); #endregion code_snippet_4 } @@ -130,7 +132,10 @@ public async Task CodeSnippet5() }); var response = await assistantAgent.SendAsync("hello"); - response.Content.Should().Be("HELLO"); + response.Should().BeOfType(); + response.From.Should().Be("assistant"); + var textMessage = (TextMessage)response; + textMessage.Content.Should().Be("HELLO"); #endregion code_snippet_5 } } diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/FunctionCallCodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/FunctionCallCodeSnippet.cs index a2973cb7874..6d37c574a4f 100644 --- a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/FunctionCallCodeSnippet.cs +++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/FunctionCallCodeSnippet.cs @@ -37,8 +37,11 @@ public async Task CodeSnippet4() }); var response = await assistantAgent.SendAsync("hello What's the weather in Seattle today? today is 2024-01-01"); - response.FunctionName.Should().Be("WeatherReport"); - response.FunctionArguments.Should().Be(@"{""location"":""Seattle"",""date"":""2024-01-01""}"); + response.Should().BeOfType(); + var toolCallMessage = (ToolCallMessage)response; + toolCallMessage.ToolCalls.Count().Should().Be(1); + toolCallMessage.ToolCalls[0].FunctionName.Should().Be("WeatherReport"); + toolCallMessage.ToolCalls[0].FunctionArguments.Should().Be(@"{""location"":""Seattle"",""date"":""2024-01-01""}"); #endregion code_snippet_4 } @@ -78,7 +81,9 @@ public async Task CodeSnippet6() #region code_snippet_6_1 var response = await assistantAgent.SendAsync("What's the weather in Seattle today? 
today is 2024-01-01"); - response.Content.Should().Be("Weather report for Seattle on 2024-01-01 is sunny"); + response.Should().BeOfType(); + var textMessage = (TextMessage)response; + textMessage.Content.Should().Be("Weather report for Seattle on 2024-01-01 is sunny"); #endregion code_snippet_6_1 } @@ -111,4 +116,4 @@ public async Task TwoAgentWeatherChatTestAsync() await user.InitiateChatAsync(assistant, "what's weather in Seattle today, today is 2024-01-01", 10); #endregion two_agent_weather_chat } -} \ No newline at end of file +} diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/MiddlewareAgentCodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/MiddlewareAgentCodeSnippet.cs index 69d74256ec3..997ee022ec3 100644 --- a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/MiddlewareAgentCodeSnippet.cs +++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/MiddlewareAgentCodeSnippet.cs @@ -20,54 +20,56 @@ public async Task CodeSnippet1() // Since no middleware is added, middlewareAgent will simply proxy into the inner agent to generate reply. 
var reply = await middlewareAgent.SendAsync("Hello World"); reply.From.Should().Be("assistant"); - reply.Content.Should().Be("Hello World"); + reply.GetContent().Should().Be("Hello World"); #endregion code_snippet_1 #region code_snippet_2 middlewareAgent.Use(async (messages, options, agent, ct) => { - var lastMessage = messages.Last(); + var lastMessage = messages.Last() as TextMessage; lastMessage.Content = $"[middleware 0] {lastMessage.Content}"; return await agent.GenerateReplyAsync(messages, options, ct); }); reply = await middlewareAgent.SendAsync("Hello World"); - reply.Content.Should().Be("[middleware 0] Hello World"); + reply.Should().BeOfType(); + var textReply = (TextMessage)reply; + textReply.Content.Should().Be("[middleware 0] Hello World"); #endregion code_snippet_2 #region code_snippet_2_1 middlewareAgent = agent.RegisterMiddleware(async (messages, options, agnet, ct) => { - var lastMessage = messages.Last(); + var lastMessage = messages.Last() as TextMessage; lastMessage.Content = $"[middleware 0] {lastMessage.Content}"; return await agent.GenerateReplyAsync(messages, options, ct); }); reply = await middlewareAgent.SendAsync("Hello World"); - reply.Content.Should().Be("[middleware 0] Hello World"); + reply.GetContent().Should().Be("[middleware 0] Hello World"); #endregion code_snippet_2_1 #region code_snippet_3 middlewareAgent.Use(async (messages, options, agent, ct) => { - var lastMessage = messages.Last(); + var lastMessage = messages.Last() as TextMessage; lastMessage.Content = $"[middleware 1] {lastMessage.Content}"; return await agent.GenerateReplyAsync(messages, options, ct); }); reply = await middlewareAgent.SendAsync("Hello World"); - reply.Content.Should().Be("[middleware 0] [middleware 1] Hello World"); + reply.GetContent().Should().Be("[middleware 0] [middleware 1] Hello World"); #endregion code_snippet_3 #region code_snippet_4 middlewareAgent.Use(async (messages, options, next, ct) => { - var lastMessage = messages.Last(); + var 
lastMessage = messages.Last() as TextMessage; lastMessage.Content = $"[middleware shortcut]"; return lastMessage; }); reply = await middlewareAgent.SendAsync("Hello World"); - reply.Content.Should().Be("[middleware shortcut]"); + reply.GetContent().Should().Be("[middleware shortcut]"); #endregion code_snippet_4 #region retrieve_inner_agent @@ -92,7 +94,7 @@ public async Task CodeSnippet1() var reply = await agent.GenerateReplyAsync(messages, options, ct); while (maxAttempt-- > 0) { - if (JsonSerializer.Deserialize>(reply.Content) is { } dict) + if (JsonSerializer.Deserialize>(reply.GetContent()) is { } dict) { return reply; } diff --git a/dotnet/sample/AutoGen.BasicSamples/Example01_AssistantAgent.cs b/dotnet/sample/AutoGen.BasicSamples/Example01_AssistantAgent.cs index 908256f4479..5a7c9612cea 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Example01_AssistantAgent.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Example01_AssistantAgent.cs @@ -30,16 +30,20 @@ public static async Task RunAsync() // talk to the assistant agent var reply = await assistantAgent.SendAsync("hello world"); - reply.Content?.Should().Be("HELLO WORLD"); + reply.Should().BeOfType(); + var textReply = (TextMessage)reply; + textReply.Content.Should().Be("HELLO WORLD"); // to carry on the conversation, pass the previous conversation history to the next call - var conversationHistory = new List + var conversationHistory = new List { - new Message(Role.User, "hello world"), // first message + new TextMessage(Role.User, "hello world"), // first message reply, // reply from assistant agent }; reply = await assistantAgent.SendAsync("hello world again", conversationHistory); - reply.Content?.Should().Be("HELLO WORLD AGAIN"); + reply.Should().BeOfType(); + textReply = (TextMessage)reply; + textReply.Content?.Should().Be("HELLO WORLD AGAIN"); } } diff --git a/dotnet/sample/AutoGen.BasicSamples/Example02_TwoAgent_MathChat.cs b/dotnet/sample/AutoGen.BasicSamples/Example02_TwoAgent_MathChat.cs index 
af4877ea214..f847ed6d69f 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Example02_TwoAgent_MathChat.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Example02_TwoAgent_MathChat.cs @@ -27,7 +27,7 @@ public static async Task RunAsync() }) .RegisterPostProcess(async (_, reply, _) => { - if (reply.Content?.Contains("TERMINATE") is true) + if (reply.GetContent()?.Contains("TERMINATE") is true) { return new Message(Role.Assistant, GroupChatExtension.TERMINATE) { diff --git a/dotnet/sample/AutoGen.BasicSamples/Example03_Agent_FunctionCall.cs b/dotnet/sample/AutoGen.BasicSamples/Example03_Agent_FunctionCall.cs index 2d43eaca616..74ad84d7c92 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Example03_Agent_FunctionCall.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Example03_Agent_FunctionCall.cs @@ -76,12 +76,19 @@ public async Task RunAsync() // talk to the assistant agent var upperCase = await agent.SendAsync("convert to upper case: hello world"); - upperCase.Content?.Should().Be("HELLO WORLD"); + upperCase.Should().BeOfType(); + var upperCaseResult = (ToolCallResultMessage)upperCase; + upperCaseResult.ToolCalls.First().Result?.Should().Be("HELLO WORLD"); + upperCaseResult.ToolCalls.Count().Should().Be(1); var concatString = await agent.SendAsync("concatenate strings: a, b, c, d, e"); - concatString.Content?.Should().Be("a b c d e"); + concatString.Should().BeOfType(); + var concatStringResult = (ToolCallResultMessage)concatString; + concatStringResult.ToolCalls.First().Result?.Should().Be("a b c d e"); var calculateTax = await agent.SendAsync("calculate tax: 100, 0.1"); - calculateTax.Content?.Should().Be("tax is 10"); + calculateTax.Should().BeOfType(); + var calculateTaxResult = (ToolCallResultMessage)calculateTax; + calculateTaxResult.ToolCalls.First().Result?.Should().Be("tax is 10"); } } diff --git a/dotnet/sample/AutoGen.BasicSamples/Example05_Dalle_And_GPT4V.cs b/dotnet/sample/AutoGen.BasicSamples/Example05_Dalle_And_GPT4V.cs index c93a076f890..428677885c1 100644 --- 
a/dotnet/sample/AutoGen.BasicSamples/Example05_Dalle_And_GPT4V.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Example05_Dalle_And_GPT4V.cs @@ -2,7 +2,6 @@ // Example05_Dalle_And_GPT4V.cs using AutoGen; -using AutoGen.OpenAI; using Azure.AI.OpenAI; using FluentAssertions; using autogen = AutoGen.LLMConfigAPI; @@ -21,7 +20,7 @@ public Example05_Dalle_And_GPT4V(OpenAIClient openAIClient) /// /// prompt with feedback /// - [FunctionAttribute] + [Function] public async Task GenerateImage(string prompt) { // TODO @@ -85,10 +84,10 @@ public static async Task RunAsync() .RegisterReply(async (msgs, ct) => { // if last message contains [TERMINATE], then find the last image url and terminate the conversation - if (msgs.Last().Content?.Contains("TERMINATE") is true) + if (msgs.Last().GetContent()?.Contains("TERMINATE") is true) { - var lastMessageWithImage = msgs.Last(msg => msg.Content?.Contains("IMAGE_GENERATION") is true); - var lastImageUrl = lastMessageWithImage.Content!.Split("\n").Last(); + var lastMessageWithImage = msgs.Last(msg => msg is ImageMessage) as ImageMessage; + var lastImageUrl = lastMessageWithImage.Url; Console.WriteLine($"download image from {lastImageUrl} to {imagePath}"); var httpClient = new HttpClient(); var imageBytes = await httpClient.GetByteArrayAsync(lastImageUrl); @@ -97,7 +96,7 @@ public static async Task RunAsync() var messageContent = $@"{GroupChatExtension.TERMINATE} {lastImageUrl}"; - return new Message(Role.Assistant, messageContent) + return new TextMessage(Role.Assistant, messageContent) { From = "dalle", }; @@ -125,7 +124,7 @@ public static async Task RunAsync() }).RegisterReply(async (msgs, ct) => { // if no image is generated, then ask DALL-E agent to generate image - if (msgs.Last().Content?.Contains("IMAGE_GENERATION") is false) + if (msgs.Last() is not ImageMessage) { return new Message(Role.Assistant, "Hey dalle, please generate image") { @@ -140,15 +139,12 @@ public static async Task RunAsync() // add image url to message metadata 
so it can be recognized by GPT-4V return msgs.Select(msg => { - if (msg.Content?.Contains("IMAGE_GENERATION") is true) + if (msg.GetContent() is string content && content.Contains("IMAGE_GENERATION")) { - var imageUrl = msg.Content.Split("\n").Last(); - var imageMessageItem = new ChatMessageImageContentItem(new Uri(imageUrl)); - var gpt4VMessage = new ChatRequestUserMessage(imageMessageItem); - var message = gpt4VMessage.ToMessage(); - message.From = msg.From; + var imageUrl = content.Split("\n").Last(); + var imageMessage = new ImageMessage(Role.Assistant, imageUrl, from: msg.From); - return message; + return imageMessage; } else { diff --git a/dotnet/sample/AutoGen.BasicSamples/Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs b/dotnet/sample/AutoGen.BasicSamples/Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs index 9b6a8195b00..72c15ae035f 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs @@ -85,9 +85,9 @@ public static async Task RunAsync() var reply = await innerAgent.GenerateReplyAsync(msgs, option, ct); while (maxRetry-- > 0) { - if (reply.FunctionName == nameof(ReviewCodeBlock)) + if (reply is ToolCallResultMessage toolResultMessage && toolResultMessage.ToolCalls is { Count: 1 } && toolResultMessage.ToolCalls[0].FunctionName == nameof(ReviewCodeBlock)) { - var reviewResultObj = JsonSerializer.Deserialize(reply.Content!); + var reviewResultObj = JsonSerializer.Deserialize(toolResultMessage.ToolCalls[0].Result); var reviews = new List(); if (reviewResultObj.HasMultipleCodeBlocks) { @@ -122,13 +122,11 @@ public static async Task RunAsync() sb.AppendLine($"- {review}"); } - reply.Content = sb.ToString(); - - return reply; + return new TextMessage(Role.Assistant, sb.ToString(), from: "code_reviewer"); } else { - var msg = new Message(Role.Assistant, "The code looks good, please ask runner to run the code for you.") + 
var msg = new TextMessage(Role.Assistant, "The code looks good, please ask runner to run the code for you.") { From = "code_reviewer", }; @@ -138,7 +136,7 @@ public static async Task RunAsync() } else { - var originalContent = reply.Content; + var originalContent = reply.GetContent(); var prompt = $@"Please convert the content to ReviewCodeBlock function arguments. ## Original Content @@ -222,9 +220,11 @@ public static async Task RunAsync() }) .RegisterPostProcess(async (_, reply, _) => { - if (reply.Content?.Contains("TERMINATE") is true) + if (reply is TextMessage textMessage && textMessage.Content.Contains("TERMINATE") is true) { - reply.Content += $"\n\n {GroupChatExtension.TERMINATE}"; + var content = $"{textMessage.Content}\n\n {GroupChatExtension.TERMINATE}"; + + return new TextMessage(Role.Assistant, content, from: reply.From); } return reply; @@ -258,7 +258,9 @@ public static async Task RunAsync() var lastMessage = conversationHistory.Last(); lastMessage.From.Should().Be("admin"); lastMessage.IsGroupChatTerminateMessage().Should().BeTrue(); - lastMessage.Content.Should().Contain(the39thFibonacciNumber.ToString()); + lastMessage.Should().BeOfType(); + var textMessage = (TextMessage)lastMessage; + textMessage.Content.Should().Contain(the39thFibonacciNumber.ToString()); #endregion start_group_chat } } diff --git a/dotnet/sample/AutoGen.BasicSamples/Example09_LMStudio_FunctionCall.cs b/dotnet/sample/AutoGen.BasicSamples/Example09_LMStudio_FunctionCall.cs index 01957dee07a..4136c5639c6 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Example09_LMStudio_FunctionCall.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Example09_LMStudio_FunctionCall.cs @@ -90,7 +90,7 @@ public static async Task RunAsync() var reply = await innerAgent.GenerateReplyAsync(msgs, option, ct); // if reply is a function call, invoke function - var content = reply.Content; + var content = reply.GetContent(); try { if (JsonSerializer.Deserialize(content) is { } functionCall) diff --git 
a/dotnet/src/AutoGen.DotnetInteractive/Extension/AgentExtension.cs b/dotnet/src/AutoGen.DotnetInteractive/Extension/AgentExtension.cs index 34a2c593055..b0efcee2dba 100644 --- a/dotnet/src/AutoGen.DotnetInteractive/Extension/AgentExtension.cs +++ b/dotnet/src/AutoGen.DotnetInteractive/Extension/AgentExtension.cs @@ -27,13 +27,13 @@ public static class AgentExtension return agent.RegisterReply(async (msgs, ct) => { var lastMessage = msgs.LastOrDefault(); - if (lastMessage == null || lastMessage.Content is null) + if (lastMessage == null || lastMessage.GetContent() is null) { return null; } // retrieve all code blocks from last message - var codeBlocks = lastMessage.Content.Split(new[] { codeBlockPrefix }, StringSplitOptions.RemoveEmptyEntries); + var codeBlocks = lastMessage.GetContent()!.Split(new[] { codeBlockPrefix }, StringSplitOptions.RemoveEmptyEntries); if (codeBlocks.Length <= 0) { return null; diff --git a/dotnet/src/AutoGen.LMStudio/LMStudioAgent.cs b/dotnet/src/AutoGen.LMStudio/LMStudioAgent.cs index 14c8be0def2..ac5150cc9df 100644 --- a/dotnet/src/AutoGen.LMStudio/LMStudioAgent.cs +++ b/dotnet/src/AutoGen.LMStudio/LMStudioAgent.cs @@ -44,8 +44,8 @@ public class LMStudioAgent : IAgent public string Name => innerAgent.Name; - public Task GenerateReplyAsync( - IEnumerable messages, + public Task GenerateReplyAsync( + IEnumerable messages, GenerateReplyOptions? options = null, System.Threading.CancellationToken cancellationToken = default) { diff --git a/dotnet/src/AutoGen.SemanticKernel/SemanticKernelAgent.cs b/dotnet/src/AutoGen.SemanticKernel/SemanticKernelAgent.cs index bde53ff4b6a..8a51d874bc8 100644 --- a/dotnet/src/AutoGen.SemanticKernel/SemanticKernelAgent.cs +++ b/dotnet/src/AutoGen.SemanticKernel/SemanticKernelAgent.cs @@ -36,7 +36,7 @@ public class SemanticKernelAgent : IStreamingAgent public string Name { get; } - public async Task GenerateReplyAsync(IEnumerable messages, GenerateReplyOptions? 
options = null, CancellationToken cancellationToken = default) + public async Task GenerateReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) { var chatMessageContents = ProcessMessage(messages); // if there's no system message in chatMessageContents, add one to the beginning diff --git a/dotnet/src/AutoGen/Core/Agent/AssistantAgent.cs b/dotnet/src/AutoGen/Core/Agent/AssistantAgent.cs index 0a302688372..06f65042add 100644 --- a/dotnet/src/AutoGen/Core/Agent/AssistantAgent.cs +++ b/dotnet/src/AutoGen/Core/Agent/AssistantAgent.cs @@ -14,7 +14,7 @@ public class AssistantAgent : ConversableAgent string name, string systemMessage = "You are a helpful AI assistant", ConversableAgentConfig? llmConfig = null, - Func, CancellationToken, Task>? isTermination = null, + Func, CancellationToken, Task>? isTermination = null, HumanInputMode humanInputMode = HumanInputMode.NEVER, IDictionary>>? functionMap = null, string? defaultReply = null) diff --git a/dotnet/src/AutoGen/Core/Agent/ConversableAgent.cs b/dotnet/src/AutoGen/Core/Agent/ConversableAgent.cs index bb8216babb4..838fef2e858 100644 --- a/dotnet/src/AutoGen/Core/Agent/ConversableAgent.cs +++ b/dotnet/src/AutoGen/Core/Agent/ConversableAgent.cs @@ -43,7 +43,7 @@ public class ConversableAgent : IAgent IAgent? innerAgent = null, string? defaultAutoReply = null, HumanInputMode humanInputMode = HumanInputMode.NEVER, - Func, CancellationToken, Task>? isTermination = null, + Func, CancellationToken, Task>? isTermination = null, IDictionary>>? functionMap = null) { this.Name = name; @@ -59,7 +59,7 @@ public class ConversableAgent : IAgent string name, string systemMessage = "You are a helpful AI assistant", ConversableAgentConfig? llmConfig = null, - Func, CancellationToken, Task>? isTermination = null, + Func, CancellationToken, Task>? isTermination = null, HumanInputMode humanInputMode = HumanInputMode.AUTO, IDictionary>>? functionMap = null, string? 
defaultReply = null) @@ -98,17 +98,17 @@ public class ConversableAgent : IAgent public string Name { get; } - public Func, CancellationToken, Task>? IsTermination { get; } + public Func, CancellationToken, Task>? IsTermination { get; } - public async Task GenerateReplyAsync( - IEnumerable messages, + public async Task GenerateReplyAsync( + IEnumerable messages, GenerateReplyOptions? overrideOptions = null, CancellationToken cancellationToken = default) { // if there's no system message, add system message to the first of chat history - if (!messages.Any(m => m.Role == Role.System)) + if (!messages.Any(m => m.IsSystemMessage())) { - var systemMessage = new Message(Role.System, this.systemMessage, from: this.Name); + var systemMessage = new TextMessage(Role.System, this.systemMessage, from: this.Name); messages = new[] { systemMessage }.Concat(messages); } @@ -125,9 +125,8 @@ public class ConversableAgent : IAgent { if (m.From == this.Name) { - var clone = new Message(m); - clone.From = this.innerAgent.Name; - return clone; + m.From = this.innerAgent.Name; + return m; } else { diff --git a/dotnet/src/AutoGen/Core/Agent/DefaultReplyAgent.cs b/dotnet/src/AutoGen/Core/Agent/DefaultReplyAgent.cs index 82868c5c727..3c05440166c 100644 --- a/dotnet/src/AutoGen/Core/Agent/DefaultReplyAgent.cs +++ b/dotnet/src/AutoGen/Core/Agent/DefaultReplyAgent.cs @@ -21,11 +21,11 @@ public class DefaultReplyAgent : IAgent public string DefaultReply { get; } = string.Empty; - public async Task GenerateReplyAsync( - IEnumerable _, + public async Task GenerateReplyAsync( + IEnumerable _, GenerateReplyOptions? 
__ = null, CancellationToken ___ = default) { - return new Message(Role.Assistant, DefaultReply, from: this.Name); + return new TextMessage(Role.Assistant, DefaultReply, from: this.Name); } } diff --git a/dotnet/src/AutoGen/Core/Agent/GroupChatManager.cs b/dotnet/src/AutoGen/Core/Agent/GroupChatManager.cs index da38e253491..6c92041cc04 100644 --- a/dotnet/src/AutoGen/Core/Agent/GroupChatManager.cs +++ b/dotnet/src/AutoGen/Core/Agent/GroupChatManager.cs @@ -17,12 +17,12 @@ public GroupChatManager(IGroupChat groupChat) } public string Name => throw new ArgumentException("GroupChatManager does not have a name"); - public IEnumerable? Messages { get; private set; } + public IEnumerable? Messages { get; private set; } public IGroupChat GroupChat { get; } - public async Task GenerateReplyAsync( - IEnumerable messages, + public async Task GenerateReplyAsync( + IEnumerable messages, GenerateReplyOptions? options, CancellationToken cancellationToken = default) { diff --git a/dotnet/src/AutoGen/Core/Agent/IAgent.cs b/dotnet/src/AutoGen/Core/Agent/IAgent.cs index 6101381da80..2450bee6f34 100644 --- a/dotnet/src/AutoGen/Core/Agent/IAgent.cs +++ b/dotnet/src/AutoGen/Core/Agent/IAgent.cs @@ -18,8 +18,8 @@ public interface IAgent /// /// conversation history /// completion option. If provided, it should override existing option if there's any - public Task GenerateReplyAsync( - IEnumerable messages, + public Task GenerateReplyAsync( + IEnumerable messages, GenerateReplyOptions? 
options = null, CancellationToken cancellationToken = default); } diff --git a/dotnet/src/AutoGen/Core/Agent/MiddlewareAgent.cs b/dotnet/src/AutoGen/Core/Agent/MiddlewareAgent.cs index c6193ba871d..3c4e99e9b2f 100644 --- a/dotnet/src/AutoGen/Core/Agent/MiddlewareAgent.cs +++ b/dotnet/src/AutoGen/Core/Agent/MiddlewareAgent.cs @@ -50,8 +50,8 @@ public MiddlewareAgent(MiddlewareAgent other) /// public IEnumerable Middlewares => this.middlewares; - public Task GenerateReplyAsync( - IEnumerable messages, + public Task GenerateReplyAsync( + IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) { @@ -69,7 +69,7 @@ public MiddlewareAgent(MiddlewareAgent other) /// Call into the next function to continue the execution of the next middleware. /// Short cut middleware execution by not calling into the next function. /// - public void Use(Func, GenerateReplyOptions?, IAgent, CancellationToken, Task> func, string? middlewareName = null) + public void Use(Func, GenerateReplyOptions?, IAgent, CancellationToken, Task> func, string? middlewareName = null) { this.middlewares.Add(new DelegateMiddleware(middlewareName, async (context, agent, cancellationToken) => { @@ -77,7 +77,7 @@ public void Use(Func, GenerateReplyOptions?, IAgent, Cancel })); } - public void Use(Func> func, string? middlewareName = null) + public void Use(Func> func, string? middlewareName = null) { this.middlewares.Add(new DelegateMiddleware(middlewareName, func)); } @@ -100,8 +100,8 @@ public DelegateAgent(IMiddleware middleware, IAgent innerAgent) public string Name { get => this.innerAgent.Name; } - public Task GenerateReplyAsync( - IEnumerable messages, + public Task GenerateReplyAsync( + IEnumerable messages, GenerateReplyOptions? 
options = null, CancellationToken cancellationToken = default) { diff --git a/dotnet/src/AutoGen/Core/Agent/MiddlewareStreamingAgent.cs b/dotnet/src/AutoGen/Core/Agent/MiddlewareStreamingAgent.cs index 4d9b35f7825..92536b3259e 100644 --- a/dotnet/src/AutoGen/Core/Agent/MiddlewareStreamingAgent.cs +++ b/dotnet/src/AutoGen/Core/Agent/MiddlewareStreamingAgent.cs @@ -36,7 +36,7 @@ public MiddlewareStreamingAgent(IStreamingAgent agent, string? name = null, IEnu /// public IEnumerable Middlewares => _middlewares; - public async Task GenerateReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) + public async Task GenerateReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) { throw new NotImplementedException("Streaming agent does not support non-streaming reply."); } @@ -75,7 +75,7 @@ public DelegateStreamingAgent(IStreamingMiddleware middleware, IStreamingAgent n this.innerAgent = next; } - public async Task GenerateReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) + public async Task GenerateReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) { throw new NotImplementedException("Streaming agent does not support non-streaming reply."); } diff --git a/dotnet/src/AutoGen/Core/Agent/UserProxyAgent.cs b/dotnet/src/AutoGen/Core/Agent/UserProxyAgent.cs index 5bcee19c1fd..a48f07006b8 100644 --- a/dotnet/src/AutoGen/Core/Agent/UserProxyAgent.cs +++ b/dotnet/src/AutoGen/Core/Agent/UserProxyAgent.cs @@ -14,7 +14,7 @@ public class UserProxyAgent : ConversableAgent string name, string systemMessage = "You are a helpful AI assistant", ConversableAgentConfig? llmConfig = null, - Func, CancellationToken, Task>? isTermination = null, + Func, CancellationToken, Task>? 
isTermination = null, HumanInputMode humanInputMode = HumanInputMode.ALWAYS, IDictionary>>? functionMap = null, string? defaultReply = null) diff --git a/dotnet/src/AutoGen/Core/Extension/AgentExtension.cs b/dotnet/src/AutoGen/Core/Extension/AgentExtension.cs index 3f05d3d8e32..a2463674b73 100644 --- a/dotnet/src/AutoGen/Core/Extension/AgentExtension.cs +++ b/dotnet/src/AutoGen/Core/Extension/AgentExtension.cs @@ -17,13 +17,13 @@ public static class AgentExtension /// sender agent. /// chat history. /// conversation history - public static async Task SendAsync( + public static async Task SendAsync( this IAgent agent, - Message? message = null, - IEnumerable? chatHistory = null, + IMessage? message = null, + IEnumerable? chatHistory = null, CancellationToken ct = default) { - var messages = new List(); + var messages = new List(); if (chatHistory != null) { @@ -48,13 +48,13 @@ public static class AgentExtension /// message to send. will be added to the end of if provided /// chat history. /// conversation history - public static async Task SendAsync( + public static async Task SendAsync( this IAgent agent, string message, - IEnumerable? chatHistory = null, + IEnumerable? chatHistory = null, CancellationToken ct = default) { - var msg = new Message(Role.User, message); + var msg = new TextMessage(Role.User, message); return await agent.SendAsync(msg, chatHistory, ct); } @@ -67,10 +67,10 @@ public static class AgentExtension /// chat history. /// max conversation round. /// conversation history - public static async Task> SendAsync( + public static async Task> SendAsync( this IAgent agent, IAgent receiver, - IEnumerable chatHistory, + IEnumerable chatHistory, int maxRound = 10, CancellationToken ct = default) { @@ -100,20 +100,20 @@ public static class AgentExtension /// chat history. /// max conversation round. 
/// conversation history - public static async Task> SendAsync( + public static async Task> SendAsync( this IAgent agent, IAgent receiver, string message, - IEnumerable? chatHistory = null, + IEnumerable? chatHistory = null, int maxRound = 10, CancellationToken ct = default) { - var msg = new Message(Role.User, message) + var msg = new TextMessage(Role.User, message) { From = agent.Name, }; - chatHistory = chatHistory ?? new List(); + chatHistory = chatHistory ?? new List(); chatHistory = chatHistory.Append(msg); return await agent.SendAsync(receiver, chatHistory, maxRound, ct); @@ -126,17 +126,17 @@ public static class AgentExtension /// receiver agent /// message to send /// max round - public static async Task> InitiateChatAsync( + public static async Task> InitiateChatAsync( this IAgent agent, IAgent receiver, string? message = null, int maxRound = 10, CancellationToken ct = default) { - var chatHistory = new List(); + var chatHistory = new List(); if (message != null) { - var msg = new Message(Role.User, message) + var msg = new TextMessage(Role.User, message) { From = agent.Name, }; @@ -147,25 +147,25 @@ public static class AgentExtension return await agent.SendAsync(receiver, chatHistory, maxRound, ct); } - public static async Task> SendMessageToGroupAsync( + public static async Task> SendMessageToGroupAsync( this IAgent agent, IGroupChat groupChat, string msg, - IEnumerable? chatHistory = null, + IEnumerable? chatHistory = null, int maxRound = 10, CancellationToken ct = default) { - var chatMessage = new Message(Role.Assistant, msg, from: agent.Name); + var chatMessage = new TextMessage(Role.Assistant, msg, from: agent.Name); chatHistory = chatHistory ?? Enumerable.Empty(); chatHistory = chatHistory.Append(chatMessage); return await agent.SendMessageToGroupAsync(groupChat, chatHistory, maxRound, ct); } - public static async Task> SendMessageToGroupAsync( + public static async Task> SendMessageToGroupAsync( this IAgent _, IGroupChat groupChat, - IEnumerable? 
chatHistory = null, + IEnumerable? chatHistory = null, int maxRound = 10, CancellationToken ct = default) { diff --git a/dotnet/src/AutoGen/Core/Extension/GroupChatExtension.cs b/dotnet/src/AutoGen/Core/Extension/GroupChatExtension.cs index 6bca736d3c3..5a48c2e83cd 100644 --- a/dotnet/src/AutoGen/Core/Extension/GroupChatExtension.cs +++ b/dotnet/src/AutoGen/Core/Extension/GroupChatExtension.cs @@ -13,7 +13,7 @@ public static class GroupChatExtension public static void AddInitializeMessage(this IAgent agent, string message, IGroupChat groupChat) { - var msg = new Message(Role.User, message) + var msg = new TextMessage(Role.User, message) { From = agent.Name }; @@ -21,9 +21,9 @@ public static void AddInitializeMessage(this IAgent agent, string message, IGrou groupChat.AddInitializeMessage(msg); } - public static IEnumerable MessageToKeep( + public static IEnumerable MessageToKeep( this IGroupChat _, - IEnumerable messages) + IEnumerable messages) { var lastCLRMessageIndex = messages.ToList() .FindLastIndex(x => x.IsGroupChatClearMessage()); @@ -49,33 +49,33 @@ public static void AddInitializeMessage(this IAgent agent, string message, IGrou } /// - /// Return true if contains , otherwise false. + /// Return true if contains , otherwise false. /// /// /// - public static bool IsGroupChatTerminateMessage(this Message message) + public static bool IsGroupChatTerminateMessage(this IMessage message) { - return message.Content?.Contains(TERMINATE) ?? false; + return message.GetContent()?.Contains(TERMINATE) ?? false; } - public static bool IsGroupChatClearMessage(this Message message) + public static bool IsGroupChatClearMessage(this IMessage message) { - return message.Content?.Contains(CLEAR_MESSAGES) ?? false; + return message.GetContent()?.Contains(CLEAR_MESSAGES) ?? 
false; } - public static IEnumerable ProcessConversationForAgent( + public static IEnumerable ProcessConversationForAgent( this IGroupChat groupChat, - IEnumerable initialMessages, - IEnumerable messages) + IEnumerable initialMessages, + IEnumerable messages) { messages = groupChat.MessageToKeep(messages); return initialMessages.Concat(messages); } - internal static IEnumerable ProcessConversationsForRolePlay( + internal static IEnumerable ProcessConversationsForRolePlay( this IGroupChat groupChat, - IEnumerable initialMessages, - IEnumerable messages) + IEnumerable initialMessages, + IEnumerable messages) { messages = groupChat.MessageToKeep(messages); var messagesToKeep = initialMessages.Concat(messages); @@ -83,12 +83,12 @@ public static bool IsGroupChatClearMessage(this Message message) return messagesToKeep.Select((x, i) => { var msg = @$"From {x.From}: -{x.Content} +{x.GetContent()} round # {i}"; - return new Message(Role.User, content: msg); + return new TextMessage(Role.User, content: msg); }); } } diff --git a/dotnet/src/AutoGen/Core/Extension/MessageExtension.cs b/dotnet/src/AutoGen/Core/Extension/MessageExtension.cs index 52205bf67ec..da84ab28d23 100644 --- a/dotnet/src/AutoGen/Core/Extension/MessageExtension.cs +++ b/dotnet/src/AutoGen/Core/Extension/MessageExtension.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. 
// MessageExtension.cs +using System.Collections.Generic; +using System.Linq; using System.Text; namespace AutoGen; @@ -8,6 +10,16 @@ namespace AutoGen; public static class MessageExtension { private static string separator = new string('-', 20); + + public static string FormatMessage(this IMessage message) + { + return message switch + { + Message msg => msg.FormatMessage(), + _ => message.ToString(), + }; + } + public static string FormatMessage(this Message message) { var sb = new StringBuilder(); @@ -41,4 +53,60 @@ public static string FormatMessage(this Message message) return sb.ToString(); } + + public static bool IsSystemMessage(this IMessage message) + { + return message switch + { + TextMessage textMessage => textMessage.Role == Role.System, + Message msg => msg.Role == Role.System, + _ => false, + }; + } + + /// + /// Get the content from the message + /// if the message is a or , return the content + /// if the message is a and only contains one function call, return the result of that function call + /// + /// + public static string? GetContent(this IMessage message) + { + return message switch + { + TextMessage textMessage => textMessage.Content, + Message msg => msg.Content, + ToolCallResultMessage toolCallResultMessage => toolCallResultMessage.GetToolCalls().Count() == 1 ? toolCallResultMessage.GetToolCalls().First().Result : null, + _ => null, + }; + } + + /// + /// Get the role from the message if it's available. + /// + public static Role? 
GetRole(this IMessage message) + { + return message switch + { + TextMessage textMessage => textMessage.Role, + Message msg => msg.Role, + ImageMessage img => img.Role, + MultiModalMessage multiModal => multiModal.Role, + _ => null, + }; + } + + public static IEnumerable GetToolCalls(this IMessage message) + { + return message switch + { + ToolCallMessage toolCallMessage => toolCallMessage.ToolCalls, + ToolCallResultMessage toolCallResultMessage => toolCallResultMessage.ToolCalls, + Message msg => msg.FunctionName is not null && msg.FunctionArguments is not null + ? msg.Content is not null ? new List { new ToolCall(msg.FunctionName, msg.FunctionArguments, result: msg.Content) } + : new List { new ToolCall(msg.FunctionName, msg.FunctionArguments) } + : [], + _ => [], + }; + } } diff --git a/dotnet/src/AutoGen/Core/Extension/MiddlewareExtension.cs b/dotnet/src/AutoGen/Core/Extension/MiddlewareExtension.cs index c01dae463e7..a3314d5a368 100644 --- a/dotnet/src/AutoGen/Core/Extension/MiddlewareExtension.cs +++ b/dotnet/src/AutoGen/Core/Extension/MiddlewareExtension.cs @@ -23,7 +23,7 @@ public static class MiddlewareExtension /// throw when agent name is null. public static MiddlewareAgent RegisterReply( this TAgent agent, - Func, CancellationToken, Task> replyFunc) + Func, CancellationToken, Task> replyFunc) where TAgent : IAgent { return agent.RegisterMiddleware(async (messages, options, agent, ct) => @@ -61,7 +61,7 @@ public static MiddlewareAgent RegisterPrintFormatMessageHook(thi /// throw when agent name is null. public static MiddlewareAgent RegisterPostProcess( this TAgent agent, - Func, Message, CancellationToken, Task> postprocessFunc) + Func, IMessage, CancellationToken, Task> postprocessFunc) where TAgent : IAgent { return agent.RegisterMiddleware(async (messages, options, agent, ct) => @@ -78,7 +78,7 @@ public static MiddlewareAgent RegisterPrintFormatMessageHook(thi /// throw when agent name is null. 
public static IAgent RegisterPreProcess( this IAgent agent, - Func, CancellationToken, Task>> preprocessFunc) + Func, CancellationToken, Task>> preprocessFunc) { return agent.RegisterMiddleware(async (messages, options, agent, ct) => { @@ -93,7 +93,7 @@ public static MiddlewareAgent RegisterPrintFormatMessageHook(thi /// public static MiddlewareAgent RegisterMiddleware( this TAgent agent, - Func, GenerateReplyOptions?, IAgent, CancellationToken, Task> func, + Func, GenerateReplyOptions?, IAgent, CancellationToken, Task> func, string? middlewareName = null) where TAgent : IAgent { @@ -133,7 +133,7 @@ public static MiddlewareAgent RegisterPrintFormatMessageHook(thi /// public static MiddlewareAgent RegisterMiddleware( this MiddlewareAgent agent, - Func, GenerateReplyOptions?, IAgent, CancellationToken, Task> func, + Func, GenerateReplyOptions?, IAgent, CancellationToken, Task> func, string? middlewareName = null) where TAgent : IAgent { @@ -191,7 +191,7 @@ public static MiddlewareAgent RegisterPrintFormatMessageHook(thi /// public static MiddlewareStreamingAgent RegisterMiddleware( this TAgent agent, - Func>> func, + Func>> func, string? middlewareName = null) where TAgent : IStreamingAgent { @@ -206,7 +206,7 @@ public static MiddlewareAgent RegisterPrintFormatMessageHook(thi /// public static MiddlewareStreamingAgent RegisterMiddleware( this MiddlewareStreamingAgent agent, - Func>> func, + Func>> func, string? middlewareName = null) where TAgent : IStreamingAgent { diff --git a/dotnet/src/AutoGen/Core/GroupChat/GroupChat.cs b/dotnet/src/AutoGen/Core/GroupChat/GroupChat.cs index 514556d4d6f..24b2ad00894 100644 --- a/dotnet/src/AutoGen/Core/GroupChat/GroupChat.cs +++ b/dotnet/src/AutoGen/Core/GroupChat/GroupChat.cs @@ -13,10 +13,10 @@ public class GroupChat : IGroupChat { private IAgent? admin; private List agents = new List(); - private IEnumerable initializeMessages = new List(); + private IEnumerable initializeMessages = new List(); private Workflow? 
workflow = null; - public IEnumerable? Messages { get; private set; } + public IEnumerable? Messages { get; private set; } /// /// Create a group chat. The next speaker will be decided by a combination effort of the admin and the workflow. @@ -28,12 +28,12 @@ public class GroupChat : IGroupChat public GroupChat( IEnumerable members, IAgent? admin, - IEnumerable? initializeMessages = null, + IEnumerable? initializeMessages = null, Workflow? workflow = null) { this.admin = admin; this.agents = members.ToList(); - this.initializeMessages = initializeMessages ?? new List(); + this.initializeMessages = initializeMessages ?? new List(); this.workflow = workflow; this.Validation(); @@ -81,7 +81,7 @@ private void Validation() /// current speaker /// conversation history /// next speaker. - public async Task SelectNextSpeakerAsync(IAgent currentSpeaker, IEnumerable conversationHistory) + public async Task SelectNextSpeakerAsync(IAgent currentSpeaker, IEnumerable conversationHistory) { var agentNames = this.agents.Select(x => x.Name).ToList(); if (this.workflow != null) @@ -104,7 +104,7 @@ public async Task SelectNextSpeakerAsync(IAgent currentSpeaker, IEnumera throw new Exception("No admin is provided."); } - var systemMessage = new Message(Role.System, + var systemMessage = new TextMessage(Role.System, content: $@"You are in a role play game. Carefully read the conversation history and carry on the conversation. 
The available roles are: {string.Join(",", agentNames)} @@ -115,7 +115,7 @@ public async Task SelectNextSpeakerAsync(IAgent currentSpeaker, IEnumera var conv = this.ProcessConversationsForRolePlay(this.initializeMessages, conversationHistory); - var messages = new Message[] { systemMessage }.Concat(conv); + var messages = new IMessage[] { systemMessage }.Concat(conv); var response = await this.admin.GenerateReplyAsync( messages: messages, options: new GenerateReplyOptions @@ -126,24 +126,24 @@ public async Task SelectNextSpeakerAsync(IAgent currentSpeaker, IEnumera Functions = [], }); - var name = response?.Content ?? throw new Exception("No name is returned."); + var name = response?.GetContent() ?? throw new Exception("No name is returned."); // remove From name = name!.Substring(5); return this.agents.First(x => x.Name!.ToLower() == name.ToLower()); } - public void AddInitializeMessage(Message message) + public void AddInitializeMessage(IMessage message) { this.initializeMessages = this.initializeMessages.Append(message); } - public async Task> CallAsync( - IEnumerable? conversationWithName = null, + public async Task> CallAsync( + IEnumerable? 
conversationWithName = null, int maxRound = 10, CancellationToken ct = default) { - var conversationHistory = new List(); + var conversationHistory = new List(); if (conversationWithName != null) { conversationHistory.AddRange(conversationWithName); diff --git a/dotnet/src/AutoGen/Core/GroupChat/SequentialGroupChat.cs b/dotnet/src/AutoGen/Core/GroupChat/SequentialGroupChat.cs index a912ef9d792..9a5dcbe4930 100644 --- a/dotnet/src/AutoGen/Core/GroupChat/SequentialGroupChat.cs +++ b/dotnet/src/AutoGen/Core/GroupChat/SequentialGroupChat.cs @@ -12,27 +12,27 @@ namespace AutoGen; public class SequentialGroupChat : IGroupChat { private readonly List agents = new List(); - private readonly List initializeMessages = new List(); + private readonly List initializeMessages = new List(); public SequentialGroupChat( IEnumerable agents, - List? initializeMessages = null) + List? initializeMessages = null) { this.agents.AddRange(agents); - this.initializeMessages = initializeMessages ?? new List(); + this.initializeMessages = initializeMessages ?? new List(); } - public void AddInitializeMessage(Message message) + public void AddInitializeMessage(IMessage message) { this.initializeMessages.Add(message); } - public async Task> CallAsync( - IEnumerable? conversationWithName = null, + public async Task> CallAsync( + IEnumerable? conversationWithName = null, int maxRound = 10, CancellationToken ct = default) { - var conversationHistory = new List(); + var conversationHistory = new List(); if (conversationWithName != null) { conversationHistory.AddRange(conversationWithName); diff --git a/dotnet/src/AutoGen/Core/IGroupChat.cs b/dotnet/src/AutoGen/Core/IGroupChat.cs index 3ce46c117a2..4c0be66f4a2 100644 --- a/dotnet/src/AutoGen/Core/IGroupChat.cs +++ b/dotnet/src/AutoGen/Core/IGroupChat.cs @@ -9,7 +9,7 @@ namespace AutoGen; public interface IGroupChat { - void AddInitializeMessage(Message message); + void AddInitializeMessage(IMessage message); - Task> CallAsync(IEnumerable? 
conversation = null, int maxRound = 10, CancellationToken ct = default); + Task> CallAsync(IEnumerable? conversation = null, int maxRound = 10, CancellationToken ct = default); } diff --git a/dotnet/src/AutoGen/Core/Message/TextMessage.cs b/dotnet/src/AutoGen/Core/Message/TextMessage.cs index 288234830a9..183907bd5a7 100644 --- a/dotnet/src/AutoGen/Core/Message/TextMessage.cs +++ b/dotnet/src/AutoGen/Core/Message/TextMessage.cs @@ -14,7 +14,7 @@ public TextMessage(Role role, string content, string? from = null) public Role Role { get; set; } - public string Content { get; } + public string Content { get; set; } public string? From { get; set; } diff --git a/dotnet/src/AutoGen/Core/Message/ToolCallMessage.cs b/dotnet/src/AutoGen/Core/Message/ToolCallMessage.cs index 81656c014ee..1fb291a6551 100644 --- a/dotnet/src/AutoGen/Core/Message/ToolCallMessage.cs +++ b/dotnet/src/AutoGen/Core/Message/ToolCallMessage.cs @@ -2,6 +2,7 @@ // ToolCallMessage.cs using System.Collections.Generic; +using System.Linq; using System.Text; namespace AutoGen; @@ -38,7 +39,7 @@ public class ToolCallMessage : IMessage public ToolCallMessage(IEnumerable toolCalls, string? from = null) { this.From = from; - this.ToolCalls = toolCalls; + this.ToolCalls = toolCalls.ToList(); } public ToolCallMessage(string functionName, string functionArgs, string? from = null) @@ -47,7 +48,7 @@ public ToolCallMessage(string functionName, string functionArgs, string? from = this.ToolCalls = new List { new ToolCall(functionName, functionArgs) }; } - public IEnumerable ToolCalls { get; set; } + public IList ToolCalls { get; set; } public string? 
From { get; set; } diff --git a/dotnet/src/AutoGen/Core/Message/ToolCallResultMessage.cs b/dotnet/src/AutoGen/Core/Message/ToolCallResultMessage.cs index 0b4fe3d4523..4caa87ea933 100644 --- a/dotnet/src/AutoGen/Core/Message/ToolCallResultMessage.cs +++ b/dotnet/src/AutoGen/Core/Message/ToolCallResultMessage.cs @@ -2,6 +2,7 @@ // ToolCallResultMessage.cs using System.Collections.Generic; +using System.Linq; using System.Text; namespace AutoGen; @@ -11,7 +12,7 @@ public class ToolCallResultMessage : IMessage public ToolCallResultMessage(IEnumerable toolCalls, string? from = null) { this.From = from; - this.ToolCalls = toolCalls; + this.ToolCalls = toolCalls.ToList(); } public ToolCallResultMessage(string result, string functionName, string functionArgs, string? from = null) @@ -25,7 +26,7 @@ public ToolCallResultMessage(string result, string functionName, string function /// /// The original tool call message /// - public IEnumerable ToolCalls { get; set; } + public IList ToolCalls { get; set; } public string? From { get; set; } diff --git a/dotnet/src/AutoGen/Core/Middleware/DelegateMiddleware.cs b/dotnet/src/AutoGen/Core/Middleware/DelegateMiddleware.cs index bc34af3c38f..95e9a81dfe3 100644 --- a/dotnet/src/AutoGen/Core/Middleware/DelegateMiddleware.cs +++ b/dotnet/src/AutoGen/Core/Middleware/DelegateMiddleware.cs @@ -13,14 +13,14 @@ internal class DelegateMiddleware : IMiddleware /// middleware delegate. Call into the next function to continue the execution of the next middleware. Otherwise, short cut the middleware execution. /// /// cancellation token - public delegate Task MiddlewareDelegate( + public delegate Task MiddlewareDelegate( MiddlewareContext context, IAgent agent, CancellationToken cancellationToken); private readonly MiddlewareDelegate middlewareDelegate; - public DelegateMiddleware(string? name, Func> middlewareDelegate) + public DelegateMiddleware(string? 
name, Func> middlewareDelegate) { this.Name = name; this.middlewareDelegate = async (context, agent, cancellationToken) => @@ -31,7 +31,7 @@ public DelegateMiddleware(string? name, Func InvokeAsync( + public Task InvokeAsync( MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default) diff --git a/dotnet/src/AutoGen/Core/Middleware/FunctionCallMiddleware.cs b/dotnet/src/AutoGen/Core/Middleware/FunctionCallMiddleware.cs index 0da541d6c72..e2e00c1e684 100644 --- a/dotnet/src/AutoGen/Core/Middleware/FunctionCallMiddleware.cs +++ b/dotnet/src/AutoGen/Core/Middleware/FunctionCallMiddleware.cs @@ -29,11 +29,11 @@ public class FunctionCallMiddleware : IMiddleware public string? Name { get; } - public async Task InvokeAsync(MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default) + public async Task InvokeAsync(MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default) { // if the last message is a function call message, invoke the function and return the result instead of sending to the agent. var lastMessage = context.Messages.Last(); - if (lastMessage is not null && lastMessage is { Content: null, FunctionName: string functionName, FunctionArguments: string functionArguments }) + if (lastMessage is Message msg && msg is { Content: null, FunctionName: string functionName, FunctionArguments: string functionArguments }) { if (this.functionMap?.TryGetValue(functionName, out var func) is true) { @@ -68,7 +68,7 @@ public async Task InvokeAsync(MiddlewareContext context, IAgent agent, var reply = await agent.GenerateReplyAsync(context.Messages, options, cancellationToken); // if the reply is a function call message plus the function's name is available in function map, invoke the function and return the result instead of sending to the agent. 
- if (reply is { FunctionName: string fName, FunctionArguments: string fArgs }) + if (reply is Message message && message is { FunctionName: string fName, FunctionArguments: string fArgs }) { if (this.functionMap?.TryGetValue(fName, out var func) is true) { diff --git a/dotnet/src/AutoGen/Core/Middleware/HumanInputMiddleware.cs b/dotnet/src/AutoGen/Core/Middleware/HumanInputMiddleware.cs index 68af4995d19..9295593f55a 100644 --- a/dotnet/src/AutoGen/Core/Middleware/HumanInputMiddleware.cs +++ b/dotnet/src/AutoGen/Core/Middleware/HumanInputMiddleware.cs @@ -18,7 +18,7 @@ public class HumanInputMiddleware : IMiddleware private readonly HumanInputMode mode; private readonly string prompt; private readonly string exitKeyword; - private Func, CancellationToken, Task> isTermination; + private Func, CancellationToken, Task> isTermination; private Func getInput = Console.ReadLine; private Action writeLine = Console.WriteLine; public string? Name => nameof(HumanInputMiddleware); @@ -27,7 +27,7 @@ public class HumanInputMiddleware : IMiddleware string prompt = "Please give feedback: Press enter or type 'exit' to stop the conversation.", string exitKeyword = "exit", HumanInputMode mode = HumanInputMode.AUTO, - Func, CancellationToken, Task>? isTermination = null, + Func, CancellationToken, Task>? isTermination = null, Func? getInput = null, Action? writeLine = null) { @@ -39,7 +39,7 @@ public class HumanInputMiddleware : IMiddleware this.writeLine = writeLine ?? 
WriteLine; } - public async Task InvokeAsync(MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default) + public async Task InvokeAsync(MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default) { // if the mode is never, then just return the input message if (mode == HumanInputMode.NEVER) @@ -81,7 +81,7 @@ public async Task InvokeAsync(MiddlewareContext context, IAgent agent, throw new InvalidOperationException("Invalid mode"); } - private async Task DefaultIsTermination(IEnumerable messages, CancellationToken _) + private async Task DefaultIsTermination(IEnumerable messages, CancellationToken _) { return messages?.Last().IsGroupChatTerminateMessage() is true; } diff --git a/dotnet/src/AutoGen/Core/Middleware/IMiddleware.cs b/dotnet/src/AutoGen/Core/Middleware/IMiddleware.cs index 647fd2b2f95..461aaebb7f7 100644 --- a/dotnet/src/AutoGen/Core/Middleware/IMiddleware.cs +++ b/dotnet/src/AutoGen/Core/Middleware/IMiddleware.cs @@ -19,7 +19,7 @@ public interface IMiddleware /// /// The method to invoke the middleware /// - public Task InvokeAsync( + public Task InvokeAsync( MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default); diff --git a/dotnet/src/AutoGen/Core/Middleware/MiddlewareContext.cs b/dotnet/src/AutoGen/Core/Middleware/MiddlewareContext.cs index 5895001aae1..9d2b29787b6 100644 --- a/dotnet/src/AutoGen/Core/Middleware/MiddlewareContext.cs +++ b/dotnet/src/AutoGen/Core/Middleware/MiddlewareContext.cs @@ -8,7 +8,7 @@ namespace AutoGen.Core.Middleware; public class MiddlewareContext { public MiddlewareContext( - IEnumerable messages, + IEnumerable messages, GenerateReplyOptions? 
options) { this.Messages = messages; @@ -18,7 +18,7 @@ public class MiddlewareContext /// /// Messages to send to the agent /// - public IEnumerable Messages { get; } + public IEnumerable Messages { get; } /// /// Options to generate the reply diff --git a/dotnet/src/AutoGen/Core/Middleware/PrintMessageMiddleware.cs b/dotnet/src/AutoGen/Core/Middleware/PrintMessageMiddleware.cs index 35b05cb3104..33a6118f1c4 100644 --- a/dotnet/src/AutoGen/Core/Middleware/PrintMessageMiddleware.cs +++ b/dotnet/src/AutoGen/Core/Middleware/PrintMessageMiddleware.cs @@ -14,7 +14,7 @@ public class PrintMessageMiddleware : IMiddleware { public string? Name => nameof(PrintMessageMiddleware); - public async Task InvokeAsync(MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default) + public async Task InvokeAsync(MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default) { var reply = await agent.GenerateReplyAsync(context.Messages, context.Options, cancellationToken); diff --git a/dotnet/src/AutoGen/Core/Workflow/Workflow.cs b/dotnet/src/AutoGen/Core/Workflow/Workflow.cs index fdba42e0234..83a795f403c 100644 --- a/dotnet/src/AutoGen/Core/Workflow/Workflow.cs +++ b/dotnet/src/AutoGen/Core/Workflow/Workflow.cs @@ -33,7 +33,7 @@ public void AddTransition(Transition transition) /// the from agent /// messages /// A list of agents that the messages can be transit to - public async Task> TransitToNextAvailableAgentsAsync(IAgent fromAgent, IEnumerable messages) + public async Task> TransitToNextAvailableAgentsAsync(IAgent fromAgent, IEnumerable messages) { var nextAgents = new List(); var availableTransitions = transitions.FindAll(t => t.From == fromAgent) ?? Enumerable.Empty(); @@ -56,17 +56,17 @@ public class Transition { private readonly IAgent _from; private readonly IAgent _to; - private readonly Func, Task>? _canTransition; + private readonly Func, Task>? _canTransition; /// /// Create a new instance of . 
/// This constructor is used for testing purpose only. - /// To create a new instance of , use . + /// To create a new instance of , use . /// /// from agent /// to agent /// detect if the transition is allowed, default to be always true - internal Transition(IAgent from, IAgent to, Func, Task>? canTransitionAsync = null) + internal Transition(IAgent from, IAgent to, Func, Task>? canTransitionAsync = null) { _from = from; _to = to; @@ -77,7 +77,7 @@ internal Transition(IAgent from, IAgent to, Func. /// /// " - public static Transition Create(TFromAgent from, TToAgent to, Func, Task>? canTransitionAsync = null) + public static Transition Create(TFromAgent from, TToAgent to, Func, Task>? canTransitionAsync = null) where TFromAgent : IAgent where TToAgent : IAgent { @@ -92,7 +92,7 @@ internal Transition(IAgent from, IAgent to, Func /// messages - public Task CanTransitionAsync(IEnumerable messages) + public Task CanTransitionAsync(IEnumerable messages) { if (_canTransition == null) { diff --git a/dotnet/src/AutoGen/OpenAI/GPTAgent.cs b/dotnet/src/AutoGen/OpenAI/GPTAgent.cs index e85f1979f5e..a5df636935d 100644 --- a/dotnet/src/AutoGen/OpenAI/GPTAgent.cs +++ b/dotnet/src/AutoGen/OpenAI/GPTAgent.cs @@ -74,8 +74,8 @@ public class GPTAgent : IStreamingAgent public string Name { get; } - public async Task GenerateReplyAsync( - IEnumerable messages, + public async Task GenerateReplyAsync( + IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) { diff --git a/dotnet/src/AutoGen/OpenAI/Middleware/OpenAIMessageConnector.cs b/dotnet/src/AutoGen/OpenAI/Middleware/OpenAIMessageConnector.cs new file mode 100644 index 00000000000..213fcba9a19 --- /dev/null +++ b/dotnet/src/AutoGen/OpenAI/Middleware/OpenAIMessageConnector.cs @@ -0,0 +1,329 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// OpenAIMessageConnector.cs + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using AutoGen.Core.Middleware; +using Azure.AI.OpenAI; + +namespace AutoGen.OpenAI.Middleware; + +/// +/// This middleware converts the incoming to before sending to agent. And converts the output to after receiving from agent. +/// Supported are +/// - +/// - +/// - +/// - +/// - +/// - +/// - where T is +/// - where TMessage1 is and TMessage2 is +/// +public class OpenAIMessageConnector : IMiddleware, IStreamingMiddleware +{ + public string? Name => nameof(OpenAIMessageConnector); + + public Task InvokeAsync(MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default) + { + var chatMessages = ProcessIncomingMessages(agent, context.Messages) + .Select(m => new MessageEnvelope(m)); + + return agent.GenerateReplyAsync(chatMessages, context.Options, cancellationToken); + } + + public Task> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + private IEnumerable ProcessIncomingMessages(IAgent agent, IEnumerable messages) + { + return messages.SelectMany(m => + { + if (m.From == null) + { + return ProcessIncomingMessagesWithEmptyFrom(m); + } + else if (m.From == agent.Name) + { + return ProcessIncomingMessagesForSelf(m); + } + else + { + return ProcessIncomingMessagesForOther(m); + } + }); + } + + private IEnumerable ProcessIncomingMessagesForSelf(IMessage message) + { + return message switch + { + TextMessage textMessage => ProcessIncomingMessagesForSelf(textMessage), + ImageMessage imageMessage => ProcessIncomingMessagesForSelf(imageMessage), + MultiModalMessage multiModalMessage => ProcessIncomingMessagesForSelf(multiModalMessage), + ToolCallMessage toolCallMessage => ProcessIncomingMessagesForSelf(toolCallMessage), + ToolCallResultMessage toolCallResultMessage => 
ProcessIncomingMessagesForSelf(toolCallResultMessage), + Message msg => ProcessIncomingMessagesForSelf(msg), + IMessage crm => ProcessIncomingMessagesForSelf(crm), + AggregateMessage aggregateMessage => ProcessIncomingMessagesForSelf(aggregateMessage), + _ => throw new NotImplementedException(), + }; + } + + private IEnumerable ProcessIncomingMessagesWithEmptyFrom(IMessage message) + { + return message switch + { + TextMessage textMessage => ProcessIncomingMessagesWithEmptyFrom(textMessage), + ImageMessage imageMessage => ProcessIncomingMessagesWithEmptyFrom(imageMessage), + MultiModalMessage multiModalMessage => ProcessIncomingMessagesWithEmptyFrom(multiModalMessage), + ToolCallMessage toolCallMessage => ProcessIncomingMessagesWithEmptyFrom(toolCallMessage), + ToolCallResultMessage toolCallResultMessage => ProcessIncomingMessagesWithEmptyFrom(toolCallResultMessage), + Message msg => ProcessIncomingMessagesWithEmptyFrom(msg), + IMessage crm => ProcessIncomingMessagesWithEmptyFrom(crm), + AggregateMessage aggregateMessage => ProcessIncomingMessagesWithEmptyFrom(aggregateMessage), + _ => throw new NotImplementedException(), + }; + } + + private IEnumerable ProcessIncomingMessagesForOther(IMessage message) + { + return message switch + { + TextMessage textMessage => ProcessIncomingMessagesForOther(textMessage), + ImageMessage imageMessage => ProcessIncomingMessagesForOther(imageMessage), + MultiModalMessage multiModalMessage => ProcessIncomingMessagesForOther(multiModalMessage), + ToolCallMessage toolCallMessage => ProcessIncomingMessagesForOther(toolCallMessage), + ToolCallResultMessage toolCallResultMessage => ProcessIncomingMessagesForOther(toolCallResultMessage), + Message msg => ProcessIncomingMessagesForOther(msg), + IMessage crm => ProcessIncomingMessagesForOther(crm), + AggregateMessage aggregateMessage => ProcessIncomingMessagesForOther(aggregateMessage), + _ => throw new NotImplementedException(), + }; + } + + private IEnumerable 
ProcessIncomingMessagesForSelf(TextMessage message) + { + if (message.Role == Role.System) + { + return new[] { new ChatRequestSystemMessage(message.Content) }; + } + else + { + return new[] { new ChatRequestAssistantMessage(message.Content) }; + } + } + + private IEnumerable ProcessIncomingMessagesForSelf(ImageMessage _) + { + return [new ChatRequestAssistantMessage("// Image Message is not supported")]; + } + + private IEnumerable ProcessIncomingMessagesForSelf(MultiModalMessage _) + { + return [new ChatRequestAssistantMessage("// MultiModal Message is not supported")]; + } + + private IEnumerable ProcessIncomingMessagesForSelf(ToolCallMessage message) + { + var toolCall = message.ToolCalls.Select(tc => new ChatCompletionsFunctionToolCall(tc.FunctionName, tc.FunctionName, tc.FunctionArguments)); + var chatRequestMessage = new ChatRequestAssistantMessage(string.Empty); + foreach (var tc in toolCall) + { + chatRequestMessage.ToolCalls.Add(tc); + } + + return new[] { chatRequestMessage }; + } + + private IEnumerable ProcessIncomingMessagesForSelf(ToolCallResultMessage message) + { + return message.ToolCalls.Select(tc => new ChatRequestToolMessage(tc.Result, tc.FunctionName)); + } + + private IEnumerable ProcessIncomingMessagesForSelf(Message message) + { + if (message.Role == Role.System) + { + return new[] { new ChatRequestSystemMessage(message.Content) }; + } + else if (message.Content is string content && content is { Length: > 0 }) + { + if (message.FunctionName is null) + { + return new[] { new ChatRequestAssistantMessage(message.Content) }; + } + else + { + return new[] { new ChatRequestFunctionMessage(message.FunctionName, content) }; + } + } + else if (message.FunctionName is string functionName) + { + return new[] + { + new ChatRequestAssistantMessage(string.Empty) + { + FunctionCall = new FunctionCall(functionName, message.FunctionArguments) + } + }; + } + else + { + throw new InvalidOperationException("Invalid Message as message from self."); + } + } + + 
private IEnumerable ProcessIncomingMessagesForSelf(IMessage message) + { + return new[] { message.Content }; + } + + private IEnumerable ProcessIncomingMessagesForSelf(AggregateMessage aggregateMessage) + { + var toolCallMessage1 = aggregateMessage.Message1; + var toolCallResultMessage = aggregateMessage.Message2; + + var assistantMessage = new ChatRequestAssistantMessage(string.Empty); + var toolCalls = toolCallMessage1.ToolCalls.Select(tc => new ChatCompletionsFunctionToolCall(tc.FunctionName, tc.FunctionName, tc.FunctionArguments)); + foreach (var tc in toolCalls) + { + assistantMessage.ToolCalls.Add(tc); + } + + var toolCallResults = toolCallResultMessage.ToolCalls.Select(tc => new ChatRequestToolMessage(tc.Result, tc.FunctionName)); + + // return assistantMessage and tool call result messages + var messages = new List { assistantMessage }; + messages.AddRange(toolCallResults); + + return messages; + } + + private IEnumerable ProcessIncomingMessagesForOther(TextMessage message) + { + if (message.Role == Role.System) + { + return new[] { new ChatRequestSystemMessage(message.Content) }; + } + else + { + return new[] { new ChatRequestUserMessage(message.Content) }; + } + } + + private IEnumerable ProcessIncomingMessagesForOther(ImageMessage message) + { + return new[] { new ChatRequestUserMessage([ + new ChatMessageImageContentItem(new Uri(message.Url)), + ])}; + } + + private IEnumerable ProcessIncomingMessagesForOther(MultiModalMessage message) + { + IEnumerable items = message.Content.Select(ci => ci switch + { + TextMessage text => new ChatMessageTextContentItem(text.Content), + ImageMessage image => new ChatMessageImageContentItem(new Uri(image.Url)), + _ => throw new NotImplementedException(), + }); + + return new[] { new ChatRequestUserMessage(items) }; + } + + private IEnumerable ProcessIncomingMessagesForOther(ToolCallMessage _) + { + return [new ChatRequestUserMessage("// ToolCall Message Type is not supported")]; + } + + private IEnumerable 
ProcessIncomingMessagesForOther(ToolCallResultMessage message) + { + return message.ToolCalls.Select(tc => new ChatRequestUserMessage(tc.Result)); + } + + private IEnumerable ProcessIncomingMessagesForOther(Message message) + { + if (message.Role == Role.System) + { + return new[] { new ChatRequestSystemMessage(message.Content) }; + } + else if (message.Content is string content && content is { Length: > 0 }) + { + if (message.FunctionName is not null) + { + return new[] { new ChatRequestToolMessage(content, message.FunctionName) }; + } + + return new[] { new ChatRequestUserMessage(message.Content) }; + } + else if (message.FunctionName is string _) + { + return new[] + { + new ChatRequestUserMessage("// Message type is not supported"), + }; + } + else + { + throw new InvalidOperationException("Invalid Message as message from other."); + } + } + + private IEnumerable ProcessIncomingMessagesForOther(IMessage message) + { + return new[] { message.Content }; + } + + private IEnumerable ProcessIncomingMessagesForOther(AggregateMessage aggregateMessage) + { + // convert as user message + var resultMessage = aggregateMessage.Message2; + + return ProcessIncomingMessagesForOther(resultMessage); + } + + private IEnumerable ProcessIncomingMessagesWithEmptyFrom(TextMessage message) + { + return ProcessIncomingMessagesForOther(message); + } + + private IEnumerable ProcessIncomingMessagesWithEmptyFrom(ImageMessage message) + { + return ProcessIncomingMessagesForOther(message); + } + + private IEnumerable ProcessIncomingMessagesWithEmptyFrom(MultiModalMessage message) + { + return ProcessIncomingMessagesForOther(message); + } + + private IEnumerable ProcessIncomingMessagesWithEmptyFrom(ToolCallMessage message) + { + return ProcessIncomingMessagesForSelf(message); + } + + private IEnumerable ProcessIncomingMessagesWithEmptyFrom(ToolCallResultMessage message) + { + return ProcessIncomingMessagesForOther(message); + } + + private IEnumerable 
ProcessIncomingMessagesWithEmptyFrom(Message message) + { + return ProcessIncomingMessagesForOther(message); + } + + private IEnumerable ProcessIncomingMessagesWithEmptyFrom(IMessage message) + { + return new[] { message.Content }; + } + + private IEnumerable ProcessIncomingMessagesWithEmptyFrom(AggregateMessage aggregateMessage) + { + return ProcessIncomingMessagesForOther(aggregateMessage); + } +} diff --git a/dotnet/test/AutoGen.Tests/EchoAgent.cs b/dotnet/test/AutoGen.Tests/EchoAgent.cs index 01210e1eae2..28a7b91bad5 100644 --- a/dotnet/test/AutoGen.Tests/EchoAgent.cs +++ b/dotnet/test/AutoGen.Tests/EchoAgent.cs @@ -16,8 +16,8 @@ public EchoAgent(string name) } public string Name { get; } - public Task GenerateReplyAsync( - IEnumerable conversation, + public Task GenerateReplyAsync( + IEnumerable conversation, GenerateReplyOptions? options = null, CancellationToken ct = default) { @@ -25,7 +25,7 @@ public EchoAgent(string name) var lastMessage = conversation.Last(); lastMessage.From = this.Name; - return Task.FromResult(lastMessage); + return Task.FromResult(lastMessage); } } } diff --git a/dotnet/test/AutoGen.Tests/MathClassTest.cs b/dotnet/test/AutoGen.Tests/MathClassTest.cs index f64efee393b..367d31668ce 100644 --- a/dotnet/test/AutoGen.Tests/MathClassTest.cs +++ b/dotnet/test/AutoGen.Tests/MathClassTest.cs @@ -92,7 +92,7 @@ public async Task AssistantAgentMathChatTestAsync() var formattedMessage = reply.FormatMessage(); this._output.WriteLine(formattedMessage); - if (reply.Content?.Contains("[UPDATE_PROGRESS]") is true) + if (reply.GetContent()?.Contains("[UPDATE_PROGRESS]") is true) { return reply; } @@ -219,17 +219,17 @@ private async Task RunMathChatAsync(IAgent teacher, IAgent student, IAgent admin } // check if there's five questions from teacher - chatHistory.Where(msg => msg.From == teacher.Name && msg.Content?.Contains("[MATH_QUESTION]") is true) + chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[MATH_QUESTION]") is 
true) .Count() .Should().BeGreaterThanOrEqualTo(5); // check if there's more than five answers from student (answer might be wrong) - chatHistory.Where(msg => msg.From == student.Name && msg.Content?.Contains("[MATH_ANSWER]") is true) + chatHistory.Where(msg => msg.From == student.Name && msg.GetContent()?.Contains("[MATH_ANSWER]") is true) .Count() .Should().BeGreaterThanOrEqualTo(5); // check if there's five answer_is_correct from teacher - chatHistory.Where(msg => msg.From == teacher.Name && msg.Content?.Contains("[ANSWER_IS_CORRECT]") is true) + chatHistory.Where(msg => msg.From == teacher.Name && msg.GetContent()?.Contains("[ANSWER_IS_CORRECT]") is true) .Count() .Should().BeGreaterThanOrEqualTo(5); diff --git a/dotnet/test/AutoGen.Tests/MiddlewareAgentTest.cs b/dotnet/test/AutoGen.Tests/MiddlewareAgentTest.cs index 405d91b1947..dfb90324c09 100644 --- a/dotnet/test/AutoGen.Tests/MiddlewareAgentTest.cs +++ b/dotnet/test/AutoGen.Tests/MiddlewareAgentTest.cs @@ -21,39 +21,39 @@ public async Task MiddlewareAgentUseTestAsync() // the reply should be the same as the original agent middlewareAgent.Name.Should().Be("echo"); var reply = await middlewareAgent.SendAsync("hello"); - reply.Content.Should().Be("hello"); + reply.GetContent().Should().Be("hello"); middlewareAgent.Use(async (messages, options, agent, ct) => { - var lastMessage = messages.Last(); - lastMessage.Content = $"[middleware 0] {lastMessage.Content}"; + var lastMessage = messages.Last() as TextMessage; + lastMessage!.Content = $"[middleware 0] {lastMessage.Content}"; return await agent.GenerateReplyAsync(messages, options, ct); }); reply = await middlewareAgent.SendAsync("hello"); - reply.Content.Should().Be("[middleware 0] hello"); + reply.GetContent().Should().Be("[middleware 0] hello"); middlewareAgent.Use(async (messages, options, agent, ct) => { - var lastMessage = messages.Last(); - lastMessage.Content = $"[middleware 1] {lastMessage.Content}"; + var lastMessage = messages.Last() as TextMessage; 
+ lastMessage!.Content = $"[middleware 1] {lastMessage.Content}"; return await agent.GenerateReplyAsync(messages, options, ct); }); // when multiple middleware are added, they will be executed in LIFO order reply = await middlewareAgent.SendAsync("hello"); - reply.Content.Should().Be("[middleware 0] [middleware 1] hello"); + reply.GetContent().Should().Be("[middleware 0] [middleware 1] hello"); // test short cut // short cut middleware will not call next middleware middlewareAgent.Use(async (messages, options, next, ct) => { - var lastMessage = messages.Last(); - lastMessage.Content = $"[middleware shortcut] {lastMessage.Content}"; + var lastMessage = messages.Last() as TextMessage; + lastMessage!.Content = $"[middleware shortcut] {lastMessage.Content}"; return lastMessage; }); reply = await middlewareAgent.SendAsync("hello"); - reply.Content.Should().Be("[middleware shortcut] hello"); + reply.GetContent().Should().Be("[middleware shortcut] hello"); } [Fact] @@ -64,38 +64,38 @@ public async Task RegisterMiddlewareTestAsync() // RegisterMiddleware will return a new agent and keep the original agent unchanged var middlewareAgent = echoAgent.RegisterMiddleware(async (messages, options, agent, ct) => { - var lastMessage = messages.Last(); - lastMessage.Content = $"[middleware 0] {lastMessage.Content}"; + var lastMessage = messages.Last() as TextMessage; + lastMessage!.Content = $"[middleware 0] {lastMessage.Content}"; return await agent.GenerateReplyAsync(messages, options, ct); }); var reply = await middlewareAgent.SendAsync("hello"); - reply.Content.Should().Be("[middleware 0] hello"); + reply.GetContent().Should().Be("[middleware 0] hello"); reply = await echoAgent.SendAsync("hello"); - reply.Content.Should().Be("hello"); + reply.GetContent().Should().Be("hello"); // when multiple middleware are added, they will be executed in LIFO order middlewareAgent = middlewareAgent.RegisterMiddleware(async (messages, options, agent, ct) => { - var lastMessage = 
messages.Last(); - lastMessage.Content = $"[middleware 1] {lastMessage.Content}"; + var lastMessage = messages.Last() as TextMessage; + lastMessage!.Content = $"[middleware 1] {lastMessage.Content}"; return await agent.GenerateReplyAsync(messages, options, ct); }); reply = await middlewareAgent.SendAsync("hello"); - reply.Content.Should().Be("[middleware 0] [middleware 1] hello"); + reply.GetContent().Should().Be("[middleware 0] [middleware 1] hello"); // test short cut // short cut middleware will not call next middleware middlewareAgent = middlewareAgent.RegisterMiddleware(async (messages, options, agent, ct) => { - var lastMessage = messages.Last(); - lastMessage.Content = $"[middleware shortcut] {lastMessage.Content}"; + var lastMessage = messages.Last() as TextMessage; + lastMessage!.Content = $"[middleware shortcut] {lastMessage.Content}"; return lastMessage; }); reply = await middlewareAgent.SendAsync("hello"); - reply.Content.Should().Be("[middleware shortcut] hello"); + reply.GetContent().Should().Be("[middleware shortcut] hello"); middlewareAgent.Middlewares.Count().Should().Be(3); } diff --git a/dotnet/test/AutoGen.Tests/MiddlewareTest.cs b/dotnet/test/AutoGen.Tests/MiddlewareTest.cs index 712157a030f..050ee3df34f 100644 --- a/dotnet/test/AutoGen.Tests/MiddlewareTest.cs +++ b/dotnet/test/AutoGen.Tests/MiddlewareTest.cs @@ -29,7 +29,7 @@ public async Task HumanInputMiddlewareTestAsync() var neverInputAgent = agent.RegisterMiddleware(neverAskUserInputMW); var reply = await neverInputAgent.SendAsync("hello"); - reply.Content!.Should().Be("hello"); + reply.GetContent()!.Should().Be("hello"); reply.From.Should().Be("echo"); var alwaysAskUserInputMW = new HumanInputMiddleware( @@ -38,28 +38,28 @@ public async Task HumanInputMiddlewareTestAsync() var alwaysInputAgent = agent.RegisterMiddleware(alwaysAskUserInputMW); reply = await alwaysInputAgent.SendAsync("hello"); - reply.Content!.Should().Be("input"); + reply.GetContent()!.Should().Be("input"); 
reply.From.Should().Be("echo"); // test auto mode // if the reply from echo is not terminate message, return the original reply var autoAskUserInputMW = new HumanInputMiddleware( mode: HumanInputMode.AUTO, - isTermination: async (messages, ct) => messages.Last()?.Content == "terminate", + isTermination: async (messages, ct) => messages.Last()?.GetContent() == "terminate", getInput: () => "input", exitKeyword: "exit"); var autoInputAgent = agent.RegisterMiddleware(autoAskUserInputMW); reply = await autoInputAgent.SendAsync("hello"); - reply.Content!.Should().Be("hello"); + reply.GetContent()!.Should().Be("hello"); // if the reply from echo is terminate message, asking user for input reply = await autoInputAgent.SendAsync("terminate"); - reply.Content!.Should().Be("input"); + reply.GetContent()!.Should().Be("input"); // if the reply from echo is terminate message, and user input is exit, return the TERMINATE message autoAskUserInputMW = new HumanInputMiddleware( mode: HumanInputMode.AUTO, - isTermination: async (messages, ct) => messages.Last().Content == "terminate", + isTermination: async (messages, ct) => messages.Last().GetContent() == "terminate", getInput: () => "exit", exitKeyword: "exit"); autoInputAgent = agent.RegisterMiddleware(autoAskUserInputMW); @@ -93,7 +93,7 @@ public async Task FunctionCallMiddlewareTestAsync() var testAgent = agent.RegisterMiddleware(mw); var functionCallMessage = new Message(Role.User, content: null, from: "user", functionCall: functionCall); var reply = await testAgent.SendAsync(functionCallMessage); - reply.Content!.Should().Be("[FUNC] hello"); + reply.GetContent()!.Should().Be("[FUNC] hello"); reply.From.Should().Be("echo"); // test 2 @@ -103,7 +103,7 @@ public async Task FunctionCallMiddlewareTestAsync() functionMap: new Dictionary>> { { "echo", EchoWrapper } }); testAgent = functionCallAgent.RegisterMiddleware(mw); reply = await testAgent.SendAsync("hello"); - reply.Content!.Should().Be("[FUNC] hello"); + 
reply.GetContent()!.Should().Be("[FUNC] hello"); reply.From.Should().Be("echo"); // test 3 @@ -112,7 +112,7 @@ public async Task FunctionCallMiddlewareTestAsync() functionMap: new Dictionary>> { { "echo", EchoWrapper } }); testAgent = agent.RegisterMiddleware(mw); reply = await testAgent.SendAsync("hello"); - reply.Content!.Should().Be("hello"); + reply.GetContent()!.Should().Be("hello"); reply.From.Should().Be("echo"); // test 4 @@ -122,6 +122,6 @@ public async Task FunctionCallMiddlewareTestAsync() testAgent = agent.RegisterMiddleware(mw); functionCallMessage = new Message(Role.User, content: null, from: "user", functionCall: functionCall); reply = await testAgent.SendAsync(functionCallMessage); - reply.Content!.Should().Be("Function echo is not available. Available functions are: echo2"); + reply.GetContent()!.Should().Be("Function echo is not available. Available functions are: echo2"); } } diff --git a/dotnet/test/AutoGen.Tests/RegisterReplyAgentTest.cs b/dotnet/test/AutoGen.Tests/RegisterReplyAgentTest.cs index a090b31c52b..d4866ad8736 100644 --- a/dotnet/test/AutoGen.Tests/RegisterReplyAgentTest.cs +++ b/dotnet/test/AutoGen.Tests/RegisterReplyAgentTest.cs @@ -14,12 +14,13 @@ public async Task RegisterReplyTestAsync() { IAgent echoAgent = new EchoAgent("echo"); echoAgent = echoAgent - .RegisterReply(async (conversations, ct) => new Message(Role.Assistant, "I'm your father", from: echoAgent.Name)); + .RegisterReply(async (conversations, ct) => new TextMessage(Role.Assistant, "I'm your father", from: echoAgent.Name)); var msg = new Message(Role.User, "hey"); var reply = await echoAgent.SendAsync(msg); - reply.Content.Should().Be("I'm your father"); - reply.Role.Should().Be(Role.Assistant); + reply.Should().BeOfType(); + reply.GetContent().Should().Be("I'm your father"); + reply.GetRole().Should().Be(Role.Assistant); reply.From.Should().Be("echo"); } } diff --git a/dotnet/test/AutoGen.Tests/SingleAgentTest.cs b/dotnet/test/AutoGen.Tests/SingleAgentTest.cs index 
d0f4fd5234c..0c8f00f00e3 100644 --- a/dotnet/test/AutoGen.Tests/SingleAgentTest.cs +++ b/dotnet/test/AutoGen.Tests/SingleAgentTest.cs @@ -77,9 +77,10 @@ public async Task GPTAgentVisionTestAsync() response.From.Should().Be(visionAgent.Name); var labelResponse = await gpt3Agent.SendAsync(response); + labelResponse.Should().BeOfType(); labelResponse.From.Should().Be(gpt3Agent.Name); - labelResponse.Content.Should().Be("[HIGHEST_LABEL] gpt-4 (n=5) green"); - labelResponse.FunctionName.Should().Be(nameof(GetHighestLabel)); + labelResponse.GetContent().Should().Be("[HIGHEST_LABEL] gpt-4 (n=5) green"); + labelResponse.GetToolCalls().First().FunctionName.Should().Be(nameof(GetHighestLabel)); } [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT")] @@ -129,8 +130,8 @@ public async Task AssistantAgentDefaultReplyTestAsync() var reply = await assistantAgent.SendAsync("hi"); - reply.Content.Should().Be("hello world"); - reply.Role.Should().Be(Role.Assistant); + reply.GetContent().Should().Be("hello world"); + reply.GetRole().Should().Be(Role.Assistant); reply.From.Should().Be(assistantAgent.Name); } @@ -209,9 +210,9 @@ private async Task EchoFunctionCallTestAsync(IAgent agent) var reply = await agent.SendAsync(chatHistory: new Message[] { message, helloWorld }); - reply.Role.Should().Be(Role.Assistant); + reply.GetRole().Should().Be(Role.Assistant); reply.From.Should().Be(agent.Name); - reply.FunctionName.Should().Be(nameof(EchoAsync)); + reply.GetToolCalls().First().FunctionName.Should().Be(nameof(EchoAsync)); } private async Task EchoFunctionCallExecutionTestAsync(IAgent agent) @@ -221,10 +222,10 @@ private async Task EchoFunctionCallExecutionTestAsync(IAgent agent) var reply = await agent.SendAsync(chatHistory: new Message[] { message, helloWorld }); - reply.Content.Should().Be("[ECHO] Hello world"); - reply.Role.Should().Be(Role.Assistant); + reply.GetContent().Should().Be("[ECHO] Hello world"); + reply.GetRole().Should().Be(Role.Assistant); 
reply.From.Should().Be(agent.Name); - reply.FunctionName.Should().Be(nameof(EchoAsync)); + reply.GetToolCalls().First().FunctionName.Should().Be(nameof(EchoAsync)); } private async Task EchoFunctionCallExecutionStreamingTestAsync(IStreamingAgent agent) @@ -263,8 +264,8 @@ private async Task UpperCaseTest(IAgent agent) var reply = await agent.SendAsync(chatHistory: new Message[] { message, uppCaseMessage }); - reply.Content.Should().Be("ABCDEFG"); - reply.Role.Should().Be(Role.Assistant); + reply.GetContent().Should().Be("ABCDEFG"); + reply.GetRole().Should().Be(Role.Assistant); reply.From.Should().Be(agent.Name); } diff --git a/dotnet/test/AutoGen.Tests/TwoAgentTest.cs b/dotnet/test/AutoGen.Tests/TwoAgentTest.cs index f7531efdc21..a267374efab 100644 --- a/dotnet/test/AutoGen.Tests/TwoAgentTest.cs +++ b/dotnet/test/AutoGen.Tests/TwoAgentTest.cs @@ -61,7 +61,7 @@ public async Task TwoAgentWeatherChatTestAsync() .RegisterMiddleware(async (msgs, option, agent, ct) => { var lastMessage = msgs.Last(); - if (lastMessage.FunctionName != null) + if (lastMessage.GetToolCalls().FirstOrDefault()?.FunctionName != null) { return await agent.GenerateReplyAsync(msgs, option, ct); } @@ -86,7 +86,7 @@ public async Task TwoAgentWeatherChatTestAsync() chatHistory.Last().IsGroupChatTerminateMessage().Should().BeTrue(); // the third last message should be the weather message from function - chatHistory[^3].Content.Should().Be("[GetWeatherFunction] The weather in New York is sunny"); + chatHistory[^3].GetContent().Should().Be("[GetWeatherFunction] The weather in New York is sunny"); // the # of messages should be 5 chatHistory.Length.Should().Be(5); diff --git a/dotnet/test/AutoGen.Tests/WorkflowTest.cs b/dotnet/test/AutoGen.Tests/WorkflowTest.cs index ff09835660d..65806f6dfd1 100644 --- a/dotnet/test/AutoGen.Tests/WorkflowTest.cs +++ b/dotnet/test/AutoGen.Tests/WorkflowTest.cs @@ -19,7 +19,7 @@ public async Task TransitionTestAsync() var aliceToBob = Transition.Create(alice, bob, async 
(from, to, messages) => { - if (messages.Any(m => m is { Content: "Hello" })) + if (messages.Any(m => m.GetContent() == "Hello")) { return true; } From 5194e23a3ff377be5cc37a7bb2adb8edaf610c2c Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Thu, 29 Feb 2024 15:42:39 -0800 Subject: [PATCH 11/27] more updates --- .../Example02_TwoAgent_MathChat.cs | 5 +- .../Example05_Dalle_And_GPT4V.cs | 9 +-- ...7_Dynamic_GroupChat_Calculate_Fibonacci.cs | 4 +- .../Example09_LMStudio_FunctionCall.cs | 8 +-- .../Extension/AgentExtension.cs | 2 +- .../Core/Agent/MiddlewareStreamingAgent.cs | 2 +- .../AutoGen/Core/Extension/AgentExtension.cs | 2 +- .../Core/Extension/MessageExtension.cs | 2 +- .../Core/Middleware/FunctionCallMiddleware.cs | 56 ++++++++++--------- .../Core/Middleware/HumanInputMiddleware.cs | 8 +-- dotnet/test/AutoGen.Tests/MiddlewareTest.cs | 6 +- 11 files changed, 51 insertions(+), 53 deletions(-) diff --git a/dotnet/sample/AutoGen.BasicSamples/Example02_TwoAgent_MathChat.cs b/dotnet/sample/AutoGen.BasicSamples/Example02_TwoAgent_MathChat.cs index f847ed6d69f..210362d54f5 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Example02_TwoAgent_MathChat.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Example02_TwoAgent_MathChat.cs @@ -29,10 +29,7 @@ public static async Task RunAsync() { if (reply.GetContent()?.Contains("TERMINATE") is true) { - return new Message(Role.Assistant, GroupChatExtension.TERMINATE) - { - From = reply.From, - }; + return new TextMessage(Role.Assistant, GroupChatExtension.TERMINATE, from: reply.From); } return reply; diff --git a/dotnet/sample/AutoGen.BasicSamples/Example05_Dalle_And_GPT4V.cs b/dotnet/sample/AutoGen.BasicSamples/Example05_Dalle_And_GPT4V.cs index 428677885c1..50525b1e772 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Example05_Dalle_And_GPT4V.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Example05_Dalle_And_GPT4V.cs @@ -126,10 +126,7 @@ public static async Task RunAsync() // if no image is generated, then ask DALL-E agent to generate 
image if (msgs.Last() is not ImageMessage) { - return new Message(Role.Assistant, "Hey dalle, please generate image") - { - From = "gpt4v", - }; + return new TextMessage(Role.Assistant, "Hey dalle, please generate image", from: "gpt4v"); } return null; @@ -153,9 +150,9 @@ public static async Task RunAsync() }); }).RegisterPrintFormatMessageHook(); - IEnumerable conversation = new List() + IEnumerable conversation = new List() { - new Message(Role.User, "Hey dalle, please generate image from prompt: English short hair blue cat chase after a mouse") + new TextMessage(Role.User, "Hey dalle, please generate image from prompt: English short hair blue cat chase after a mouse") }; var maxRound = 20; await gpt4VAgent.InitiateChatAsync( diff --git a/dotnet/sample/AutoGen.BasicSamples/Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs b/dotnet/sample/AutoGen.BasicSamples/Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs index 72c15ae035f..dfe8d201f60 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs @@ -188,7 +188,7 @@ public static async Task RunAsync() { if (msgs.Count() == 0) { - return new Message(Role.Assistant, "No code available. Coder please write code"); + return new TextMessage(Role.Assistant, "No code available. 
Coder please write code"); } return null; @@ -199,7 +199,7 @@ public static async Task RunAsync() var coderMsg = msgs.LastOrDefault(msg => msg.From == "coder"); if (coderMsg is null) { - return Enumerable.Empty(); + return Enumerable.Empty(); } else { diff --git a/dotnet/sample/AutoGen.BasicSamples/Example09_LMStudio_FunctionCall.cs b/dotnet/sample/AutoGen.BasicSamples/Example09_LMStudio_FunctionCall.cs index 4136c5639c6..b73307039c4 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Example09_LMStudio_FunctionCall.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Example09_LMStudio_FunctionCall.cs @@ -83,8 +83,8 @@ public static async Task RunAsync() .RegisterMiddleware(async (msgs, option, innerAgent, ct) => { // inject few-shot example to the message - var exampleGetWeather = new Message(Role.User, "Get weather in London"); - var exampleAnswer = new Message(Role.Assistant, "{\n \"name\": \"GetWeather\",\n \"arguments\": {\n \"city\": \"London\"\n }\n}", from: innerAgent.Name); + var exampleGetWeather = new TextMessage(Role.User, "Get weather in London"); + var exampleAnswer = new TextMessage(Role.Assistant, "{\n \"name\": \"GetWeather\",\n \"arguments\": {\n \"city\": \"London\"\n }\n}", from: innerAgent.Name); msgs = new[] { exampleGetWeather, exampleAnswer }.Concat(msgs).ToArray(); var reply = await innerAgent.GenerateReplyAsync(msgs, option, ct); @@ -100,12 +100,12 @@ public static async Task RunAsync() if (functionCall.Name == instance.GetWeatherFunction.Name) { var result = await instance.GetWeatherWrapper(arguments); - return new Message(Role.Assistant, result); + return new TextMessage(Role.Assistant, result); } else if (functionCall.Name == instance.GoogleSearchFunction.Name) { var result = await instance.GoogleSearchWrapper(arguments); - return new Message(Role.Assistant, result); + return new TextMessage(Role.Assistant, result); } else { diff --git a/dotnet/src/AutoGen.DotnetInteractive/Extension/AgentExtension.cs 
b/dotnet/src/AutoGen.DotnetInteractive/Extension/AgentExtension.cs index b0efcee2dba..da578fc20b2 100644 --- a/dotnet/src/AutoGen.DotnetInteractive/Extension/AgentExtension.cs +++ b/dotnet/src/AutoGen.DotnetInteractive/Extension/AgentExtension.cs @@ -73,7 +73,7 @@ public static class AgentExtension maximumOutputToKeep = result.Length; } - return new Message(Role.Assistant, result.ToString().Substring(0, maximumOutputToKeep), from: agent.Name); + return new TextMessage(Role.Assistant, result.ToString().Substring(0, maximumOutputToKeep), from: agent.Name); }); } } diff --git a/dotnet/src/AutoGen/Core/Agent/MiddlewareStreamingAgent.cs b/dotnet/src/AutoGen/Core/Agent/MiddlewareStreamingAgent.cs index 92536b3259e..4684e15fbb1 100644 --- a/dotnet/src/AutoGen/Core/Agent/MiddlewareStreamingAgent.cs +++ b/dotnet/src/AutoGen/Core/Agent/MiddlewareStreamingAgent.cs @@ -84,7 +84,7 @@ public Task> GenerateStreamingReplyAsync(IEnumerable< { // TODO // fix this - var context = new MiddlewareContext((IEnumerable)messages, options); + var context = new MiddlewareContext(messages, options); return middleware.InvokeAsync(context, innerAgent, cancellationToken); } } diff --git a/dotnet/src/AutoGen/Core/Extension/AgentExtension.cs b/dotnet/src/AutoGen/Core/Extension/AgentExtension.cs index a2463674b73..9bb10fa9d7b 100644 --- a/dotnet/src/AutoGen/Core/Extension/AgentExtension.cs +++ b/dotnet/src/AutoGen/Core/Extension/AgentExtension.cs @@ -156,7 +156,7 @@ public static class AgentExtension CancellationToken ct = default) { var chatMessage = new TextMessage(Role.Assistant, msg, from: agent.Name); - chatHistory = chatHistory ?? Enumerable.Empty(); + chatHistory = chatHistory ?? 
Enumerable.Empty(); chatHistory = chatHistory.Append(chatMessage); return await agent.SendMessageToGroupAsync(groupChat, chatHistory, maxRound, ct); diff --git a/dotnet/src/AutoGen/Core/Extension/MessageExtension.cs b/dotnet/src/AutoGen/Core/Extension/MessageExtension.cs index da84ab28d23..2d36b58a0a5 100644 --- a/dotnet/src/AutoGen/Core/Extension/MessageExtension.cs +++ b/dotnet/src/AutoGen/Core/Extension/MessageExtension.cs @@ -96,7 +96,7 @@ public static bool IsSystemMessage(this IMessage message) }; } - public static IEnumerable GetToolCalls(this IMessage message) + public static IList GetToolCalls(this IMessage message) { return message switch { diff --git a/dotnet/src/AutoGen/Core/Middleware/FunctionCallMiddleware.cs b/dotnet/src/AutoGen/Core/Middleware/FunctionCallMiddleware.cs index e2e00c1e684..c56431a6389 100644 --- a/dotnet/src/AutoGen/Core/Middleware/FunctionCallMiddleware.cs +++ b/dotnet/src/AutoGen/Core/Middleware/FunctionCallMiddleware.cs @@ -33,31 +33,31 @@ public async Task InvokeAsync(MiddlewareContext context, IAgent agent, { // if the last message is a function call message, invoke the function and return the result instead of sending to the agent. 
var lastMessage = context.Messages.Last(); - if (lastMessage is Message msg && msg is { Content: null, FunctionName: string functionName, FunctionArguments: string functionArguments }) + if (lastMessage.GetToolCalls() is IList toolCalls && toolCalls.Count() == 1) { - if (this.functionMap?.TryGetValue(functionName, out var func) is true) + var toolCallResult = new List(); + foreach (var toolCall in toolCalls) { - var result = await func(functionArguments); - return new Message(role: Role.Function, content: result, from: agent.Name) + var functionName = toolCall.FunctionName; + var functionArguments = toolCall.FunctionArguments; + if (this.functionMap?.TryGetValue(functionName, out var func) is true) { - FunctionName = functionName, - FunctionArguments = functionArguments, - }; - } - else if (this.functionMap is not null) - { - var errorMessage = $"Function {functionName} is not available. Available functions are: {string.Join(", ", this.functionMap.Select(f => f.Key))}"; + var result = await func(functionArguments); + toolCallResult.Add(new ToolCall(functionName, functionArguments, result)); + } + else if (this.functionMap is not null) + { + var errorMessage = $"Function {functionName} is not available. 
Available functions are: {string.Join(", ", this.functionMap.Select(f => f.Key))}"; - return new Message(role: Role.Function, content: errorMessage, from: agent.Name) + toolCallResult.Add(new ToolCall(functionName, functionArguments, errorMessage)); + } + else { - FunctionName = functionName, - FunctionArguments = functionArguments, - }; - } - else - { - throw new InvalidOperationException("FunctionMap is not available"); + throw new InvalidOperationException("FunctionMap is not available"); + } } + + return new ToolCallResultMessage(toolCallResult, from: agent.Name); } // combine functions @@ -68,17 +68,21 @@ public async Task InvokeAsync(MiddlewareContext context, IAgent agent, var reply = await agent.GenerateReplyAsync(context.Messages, options, cancellationToken); // if the reply is a function call message plus the function's name is available in function map, invoke the function and return the result instead of sending to the agent. - if (reply is Message message && message is { FunctionName: string fName, FunctionArguments: string fArgs }) + if (reply.GetToolCalls() is IList toolCallsReply && toolCallsReply.Count() == 1) { - if (this.functionMap?.TryGetValue(fName, out var func) is true) + var toolCallResult = new List(); + foreach (var toolCall in toolCallsReply) { - var result = await func(fArgs); - return new Message(role: Role.Assistant, content: result, from: reply.From) + var fName = toolCall.FunctionName; + var fArgs = toolCall.FunctionArguments; + if (this.functionMap?.TryGetValue(fName, out var func) is true) { - FunctionName = fName, - FunctionArguments = fArgs, - }; + var result = await func(fArgs); + toolCallResult.Add(new ToolCall(fName, fArgs, result)); + } } + + return new ToolCallResultMessage(toolCallResult, from: agent.Name); } // for all other messages, just return the reply from the agent. 
diff --git a/dotnet/src/AutoGen/Core/Middleware/HumanInputMiddleware.cs b/dotnet/src/AutoGen/Core/Middleware/HumanInputMiddleware.cs index 9295593f55a..7bfd29c3560 100644 --- a/dotnet/src/AutoGen/Core/Middleware/HumanInputMiddleware.cs +++ b/dotnet/src/AutoGen/Core/Middleware/HumanInputMiddleware.cs @@ -54,10 +54,10 @@ public async Task InvokeAsync(MiddlewareContext context, IAgent agent, var input = getInput(); if (input == exitKeyword) { - return new Message(Role.Assistant, GroupChatExtension.TERMINATE, agent.Name); + return new TextMessage(Role.Assistant, GroupChatExtension.TERMINATE, agent.Name); } - return new Message(Role.Assistant, input, agent.Name); + return new TextMessage(Role.Assistant, input, agent.Name); } // if the mode is auto, then prompt the user for input if the message is not a termination message @@ -72,10 +72,10 @@ public async Task InvokeAsync(MiddlewareContext context, IAgent agent, var input = getInput(); if (input == exitKeyword) { - return new Message(Role.Assistant, GroupChatExtension.TERMINATE, agent.Name); + return new TextMessage(Role.Assistant, GroupChatExtension.TERMINATE, agent.Name); } - return new Message(Role.Assistant, input, agent.Name); + return new TextMessage(Role.Assistant, input, agent.Name); } throw new InvalidOperationException("Invalid mode"); diff --git a/dotnet/test/AutoGen.Tests/MiddlewareTest.cs b/dotnet/test/AutoGen.Tests/MiddlewareTest.cs index 050ee3df34f..5fbbf44f00a 100644 --- a/dotnet/test/AutoGen.Tests/MiddlewareTest.cs +++ b/dotnet/test/AutoGen.Tests/MiddlewareTest.cs @@ -82,7 +82,7 @@ public async Task FunctionCallMiddlewareTestAsync() return await agent.GenerateReplyAsync(messages, options, ct); } - return new Message(Role.Assistant, content: null, from: agent.Name, functionCall: functionCall); + return new ToolCallMessage(functionCall.Name, functionCall.Arguments, from: agent.Name); }); // test 1 @@ -91,8 +91,9 @@ public async Task FunctionCallMiddlewareTestAsync() functionMap: new Dictionary>> { { 
"echo", EchoWrapper } }); var testAgent = agent.RegisterMiddleware(mw); - var functionCallMessage = new Message(Role.User, content: null, from: "user", functionCall: functionCall); + var functionCallMessage = new ToolCallMessage(functionCall.Name, functionCall.Arguments, from: "user"); var reply = await testAgent.SendAsync(functionCallMessage); + reply.Should().BeOfType(); reply.GetContent()!.Should().Be("[FUNC] hello"); reply.From.Should().Be("echo"); @@ -120,7 +121,6 @@ public async Task FunctionCallMiddlewareTestAsync() mw = new FunctionCallMiddleware( functionMap: new Dictionary>> { { "echo2", EchoWrapper } }); testAgent = agent.RegisterMiddleware(mw); - functionCallMessage = new Message(Role.User, content: null, from: "user", functionCall: functionCall); reply = await testAgent.SendAsync(functionCallMessage); reply.GetContent()!.Should().Be("Function echo is not available. Available functions are: echo2"); } From da89fb0fde929568c8c29da32315177a2b8c4a6a Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Thu, 29 Feb 2024 17:13:10 -0800 Subject: [PATCH 12/27] update --- .../Core/Extension/MessageExtension.cs | 11 +- .../Core/Extension/MiddlewareExtension.cs | 8 +- .../Core/Middleware/FunctionCallMiddleware.cs | 11 +- dotnet/src/AutoGen/OpenAI/GPTAgent.cs | 114 +++++++++++++++++- .../Middleware/OpenAIMessageConnector.cs | 18 +-- .../test/AutoGen.Tests/OpenAIMessageTests.cs | 39 +++--- dotnet/test/AutoGen.Tests/SingleAgentTest.cs | 27 ++--- dotnet/test/AutoGen.Tests/TwoAgentTest.cs | 2 +- 8 files changed, 173 insertions(+), 57 deletions(-) diff --git a/dotnet/src/AutoGen/Core/Extension/MessageExtension.cs b/dotnet/src/AutoGen/Core/Extension/MessageExtension.cs index 2d36b58a0a5..35a0c1987da 100644 --- a/dotnet/src/AutoGen/Core/Extension/MessageExtension.cs +++ b/dotnet/src/AutoGen/Core/Extension/MessageExtension.cs @@ -96,7 +96,12 @@ public static bool IsSystemMessage(this IMessage message) }; } - public static IList GetToolCalls(this IMessage message) + /// + /// 
Return the tool calls from the message if it's available. + /// + /// + /// + public static IList? GetToolCalls(this IMessage message) { return message switch { @@ -105,8 +110,8 @@ public static IList GetToolCalls(this IMessage message) Message msg => msg.FunctionName is not null && msg.FunctionArguments is not null ? msg.Content is not null ? new List { new ToolCall(msg.FunctionName, msg.FunctionArguments, result: msg.Content) } : new List { new ToolCall(msg.FunctionName, msg.FunctionArguments) } - : [], - _ => [], + : null, + _ => null, }; } } diff --git a/dotnet/src/AutoGen/Core/Extension/MiddlewareExtension.cs b/dotnet/src/AutoGen/Core/Extension/MiddlewareExtension.cs index a3314d5a368..d5223111460 100644 --- a/dotnet/src/AutoGen/Core/Extension/MiddlewareExtension.cs +++ b/dotnet/src/AutoGen/Core/Extension/MiddlewareExtension.cs @@ -160,7 +160,7 @@ public static MiddlewareAgent RegisterPrintFormatMessageHook(thi /// /// Register a middleware to an existing agent and return a new agent with the middleware. /// - public static MiddlewareStreamingAgent RegisterMiddleware( + public static MiddlewareStreamingAgent RegisterStreamingMiddleware( this TAgent agent, IStreamingMiddleware middleware) where TAgent : IStreamingAgent @@ -174,7 +174,7 @@ public static MiddlewareAgent RegisterPrintFormatMessageHook(thi /// /// Register a middleware to an existing agent and return a new agent with the middleware. /// - public static MiddlewareStreamingAgent RegisterMiddleware( + public static MiddlewareStreamingAgent RegisterStreamingMiddleware( this MiddlewareStreamingAgent agent, IStreamingMiddleware middleware) where TAgent : IStreamingAgent @@ -189,7 +189,7 @@ public static MiddlewareAgent RegisterPrintFormatMessageHook(thi /// /// Register a middleware to an existing agent and return a new agent with the middleware. 
/// - public static MiddlewareStreamingAgent RegisterMiddleware( + public static MiddlewareStreamingAgent RegisterStreamingMiddleware( this TAgent agent, Func>> func, string? middlewareName = null) @@ -204,7 +204,7 @@ public static MiddlewareAgent RegisterPrintFormatMessageHook(thi /// /// Register a middleware to an existing agent and return a new agent with the middleware. /// - public static MiddlewareStreamingAgent RegisterMiddleware( + public static MiddlewareStreamingAgent RegisterStreamingMiddleware( this MiddlewareStreamingAgent agent, Func>> func, string? middlewareName = null) diff --git a/dotnet/src/AutoGen/Core/Middleware/FunctionCallMiddleware.cs b/dotnet/src/AutoGen/Core/Middleware/FunctionCallMiddleware.cs index c56431a6389..9445f9180a2 100644 --- a/dotnet/src/AutoGen/Core/Middleware/FunctionCallMiddleware.cs +++ b/dotnet/src/AutoGen/Core/Middleware/FunctionCallMiddleware.cs @@ -68,7 +68,7 @@ public async Task InvokeAsync(MiddlewareContext context, IAgent agent, var reply = await agent.GenerateReplyAsync(context.Messages, options, cancellationToken); // if the reply is a function call message plus the function's name is available in function map, invoke the function and return the result instead of sending to the agent. - if (reply.GetToolCalls() is IList toolCallsReply && toolCallsReply.Count() == 1) + if (reply.GetToolCalls() is IList toolCallsReply) { var toolCallResult = new List(); foreach (var toolCall in toolCallsReply) @@ -82,7 +82,14 @@ public async Task InvokeAsync(MiddlewareContext context, IAgent agent, } } - return new ToolCallResultMessage(toolCallResult, from: agent.Name); + if (toolCallResult.Count() > 0) + { + return new ToolCallResultMessage(toolCallResult, from: agent.Name); + } + else + { + return reply; + } } // for all other messages, just return the reply from the agent. 
diff --git a/dotnet/src/AutoGen/OpenAI/GPTAgent.cs b/dotnet/src/AutoGen/OpenAI/GPTAgent.cs index a5df636935d..f3e2f5bf949 100644 --- a/dotnet/src/AutoGen/OpenAI/GPTAgent.cs +++ b/dotnet/src/AutoGen/OpenAI/GPTAgent.cs @@ -6,6 +6,8 @@ using System.Linq; using System.Threading; using System.Threading.Tasks; +using AutoGen.Core.Middleware; +using AutoGen.OpenAI.Middleware; using Azure.AI.OpenAI; namespace AutoGen.OpenAI; @@ -19,7 +21,7 @@ public class GPTAgent : IStreamingAgent private readonly IDictionary>>? functionMap; private readonly OpenAIClient openAIClient; private readonly string? modelName; - public const string CHUNK_KEY = "oai_msg_chunk"; + private readonly OpenAIClientAgent _innerAgent; public GPTAgent( string name, @@ -44,6 +46,7 @@ public class GPTAgent : IStreamingAgent _ => throw new ArgumentException($"Unsupported config type {config.GetType()}"), }; + _innerAgent = new OpenAIClientAgent(openAIClient, name, systemMessage, modelName, temperature, maxTokens, functions); _systemMessage = systemMessage; _functions = functions; Name = name; @@ -70,6 +73,7 @@ public class GPTAgent : IStreamingAgent _temperature = temperature; _maxTokens = maxTokens; this.functionMap = functionMap; + _innerAgent = new OpenAIClientAgent(openAIClient, name, systemMessage, modelName, temperature, maxTokens, functions); } public string Name { get; } @@ -79,11 +83,22 @@ public class GPTAgent : IStreamingAgent GenerateReplyOptions? 
options = null, CancellationToken cancellationToken = default) { - var settings = this.CreateChatCompletionsOptions(options, messages); - var response = await this.openAIClient.GetChatCompletionsAsync(settings, cancellationToken); - var oaiMessage = response.Value.Choices.First().Message; + var oaiConnectorMiddleware = new OpenAIMessageConnector(); + var agent = this._innerAgent.RegisterMiddleware(oaiConnectorMiddleware); + if (this.functionMap is not null) + { + var functionMapMiddleware = new FunctionCallMiddleware(functionMap: this.functionMap); + agent = agent.RegisterMiddleware(functionMapMiddleware); + } - return await this.PostProcessMessage(oaiMessage); + var reply = await agent.GenerateReplyAsync(messages, options, cancellationToken); + + if (reply is IMessage oaiMessage) + { + return await this.PostProcessMessage(oaiMessage.Content); + } + + throw new Exception("Invalid message type"); } public async Task> GenerateStreamingReplyAsync( @@ -255,4 +270,93 @@ private async Task PostProcessMessage(ChatResponseMessage oaiMessage) }; } } + + private class OpenAIClientAgent : IStreamingAgent + { + private readonly OpenAIClient openAIClient; + private readonly string modelName; + private readonly float _temperature; + private readonly int _maxTokens = 1024; + private readonly IEnumerable? _functions; + private readonly string _systemMessage; + + public OpenAIClientAgent( + OpenAIClient openAIClient, + string name, + string systemMessage, + string modelName, + float temperature = 0.7f, + int maxTokens = 1024, + IEnumerable? functions = null) + { + this.openAIClient = openAIClient; + this.modelName = modelName; + this.Name = name; + _temperature = temperature; + _maxTokens = maxTokens; + _functions = functions; + _systemMessage = systemMessage; + } + + public string Name { get; } + + public async Task GenerateReplyAsync( + IEnumerable messages, + GenerateReplyOptions? 
options = null, + CancellationToken cancellationToken = default) + { + var settings = this.CreateChatCompletionsOptions(options, messages); + var reply = await this.openAIClient.GetChatCompletionsAsync(settings, cancellationToken); + + return new MessageEnvelope(reply.Value.Choices.First().Message, from: this.Name); + } + + public Task> GenerateStreamingReplyAsync( + IEnumerable messages, + GenerateReplyOptions? options = null, + CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + private ChatCompletionsOptions CreateChatCompletionsOptions(GenerateReplyOptions? options, IEnumerable messages) + { + var oaiMessages = messages.Select(m => m switch + { + IMessage chatRequestMessage => chatRequestMessage.Content, + _ => throw new ArgumentException("Invalid message type") + }); + + // add system message if there's no system message in messages + if (!oaiMessages.Any(m => m is ChatRequestSystemMessage)) + { + oaiMessages = new[] { new ChatRequestSystemMessage(_systemMessage) }.Concat(oaiMessages); + } + + var settings = new ChatCompletionsOptions(this.modelName, oaiMessages) + { + MaxTokens = options?.MaxToken ?? _maxTokens, + Temperature = options?.Temperature ?? _temperature, + }; + + var functions = options?.Functions ?? 
_functions; + if (functions is not null && functions.Count() > 0) + { + foreach (var f in functions) + { + settings.Functions.Add(f); + } + } + + if (options?.StopSequence is var sequence && sequence is { Length: > 0 }) + { + foreach (var seq in sequence) + { + settings.StopSequences.Add(seq); + } + } + + return settings; + } + } } diff --git a/dotnet/src/AutoGen/OpenAI/Middleware/OpenAIMessageConnector.cs b/dotnet/src/AutoGen/OpenAI/Middleware/OpenAIMessageConnector.cs index 213fcba9a19..eb13210ab51 100644 --- a/dotnet/src/AutoGen/OpenAI/Middleware/OpenAIMessageConnector.cs +++ b/dotnet/src/AutoGen/OpenAI/Middleware/OpenAIMessageConnector.cs @@ -40,7 +40,7 @@ public Task> InvokeAsync(MiddlewareContext context, I throw new NotImplementedException(); } - private IEnumerable ProcessIncomingMessages(IAgent agent, IEnumerable messages) + public IEnumerable ProcessIncomingMessages(IAgent agent, IEnumerable messages) { return messages.SelectMany(m => { @@ -165,12 +165,14 @@ private IEnumerable ProcessIncomingMessagesForSelf(Message m } else if (message.FunctionName is string functionName) { + var msg = new ChatRequestAssistantMessage(content: null) + { + FunctionCall = new FunctionCall(functionName, message.FunctionArguments) + }; + return new[] { - new ChatRequestAssistantMessage(string.Empty) - { - FunctionCall = new FunctionCall(functionName, message.FunctionArguments) - } + msg, }; } else @@ -238,12 +240,12 @@ private IEnumerable ProcessIncomingMessagesForOther(MultiMod private IEnumerable ProcessIncomingMessagesForOther(ToolCallMessage _) { - return [new ChatRequestUserMessage("// ToolCall Message Type is not supported")]; + throw new ArgumentException("ToolCallMessage is not supported when message.From is not the same with agent"); } private IEnumerable ProcessIncomingMessagesForOther(ToolCallResultMessage message) { - return message.ToolCalls.Select(tc => new ChatRequestUserMessage(tc.Result)); + return message.ToolCalls.Select(tc => new 
ChatRequestToolMessage(tc.Result, tc.FunctionName)); } private IEnumerable ProcessIncomingMessagesForOther(Message message) @@ -284,7 +286,7 @@ private IEnumerable ProcessIncomingMessagesForOther(Aggregat // convert as user message var resultMessage = aggregateMessage.Message2; - return ProcessIncomingMessagesForOther(resultMessage); + return resultMessage.ToolCalls.Select(tc => new ChatRequestUserMessage(tc.Result)); } private IEnumerable ProcessIncomingMessagesWithEmptyFrom(TextMessage message) diff --git a/dotnet/test/AutoGen.Tests/OpenAIMessageTests.cs b/dotnet/test/AutoGen.Tests/OpenAIMessageTests.cs index 564d55ce626..2e784296fe0 100644 --- a/dotnet/test/AutoGen.Tests/OpenAIMessageTests.cs +++ b/dotnet/test/AutoGen.Tests/OpenAIMessageTests.cs @@ -8,7 +8,7 @@ using ApprovalTests; using ApprovalTests.Namers; using ApprovalTests.Reporters; -using AutoGen.OpenAI; +using AutoGen.OpenAI.Middleware; using Azure.AI.OpenAI; using FluentAssertions; using Xunit; @@ -63,10 +63,10 @@ public void BasicMessageTest() message1: new ToolCallMessage("test", "test", "assistant"), message2: new ToolCallResultMessage("result", "test", "test", "assistant"), "assistant"), ]; - + var openaiMessageConnectorMiddleware = new OpenAIMessageConnector(); var agent = new EchoAgent("assistant"); - var oaiMessages = messages.Select(m => (m, agent.ToOpenAIChatRequestMessage(m))); + var oaiMessages = messages.Select(m => (m, openaiMessageConnectorMiddleware.ProcessIncomingMessages(agent, [m]))); VerifyOAIMessages(oaiMessages); } @@ -74,10 +74,11 @@ public void BasicMessageTest() public void ToOpenAIChatRequestMessageTest() { var agent = new EchoAgent("assistant"); + var middleware = new OpenAIMessageConnector(); // user message IMessage message = new TextMessage(Role.User, "Hello", "user"); - var oaiMessages = agent.ToOpenAIChatRequestMessage(message); + var oaiMessages = middleware.ProcessIncomingMessages(agent, [message]); oaiMessages.Count().Should().Be(1); 
oaiMessages.First().Should().BeOfType(); @@ -87,7 +88,7 @@ public void ToOpenAIChatRequestMessageTest() // user message test 2 // even if Role is assistant, it should be converted to user message because it is from the user message = new TextMessage(Role.Assistant, "Hello", "user"); - oaiMessages = agent.ToOpenAIChatRequestMessage(message); + oaiMessages = middleware.ProcessIncomingMessages(agent, [message]); oaiMessages.Count().Should().Be(1); oaiMessages.First().Should().BeOfType(); @@ -97,7 +98,7 @@ public void ToOpenAIChatRequestMessageTest() // user message with multimodal content // image message = new ImageMessage(Role.User, "https://example.com/image.png", "user"); - oaiMessages = agent.ToOpenAIChatRequestMessage(message); + oaiMessages = middleware.ProcessIncomingMessages(agent, [message]); oaiMessages.Count().Should().Be(1); oaiMessages.First().Should().BeOfType(); @@ -112,7 +113,7 @@ public void ToOpenAIChatRequestMessageTest() new TextMessage(Role.User, "Hello", "user"), new ImageMessage(Role.User, "https://example.com/image.png", "user"), ], "user"); - oaiMessages = agent.ToOpenAIChatRequestMessage(message); + oaiMessages = middleware.ProcessIncomingMessages(agent, [message]); oaiMessages.Count().Should().Be(1); oaiMessages.First().Should().BeOfType(); @@ -123,7 +124,7 @@ public void ToOpenAIChatRequestMessageTest() // assistant text message message = new TextMessage(Role.Assistant, "How can I help you?", "assistant"); - oaiMessages = agent.ToOpenAIChatRequestMessage(message); + oaiMessages = middleware.ProcessIncomingMessages(agent, [message]); oaiMessages.Count().Should().Be(1); oaiMessages.First().Should().BeOfType(); @@ -132,7 +133,7 @@ public void ToOpenAIChatRequestMessageTest() // assistant text message with single tool call message = new ToolCallMessage("test", "test", "assistant"); - oaiMessages = agent.ToOpenAIChatRequestMessage(message); + oaiMessages = middleware.ProcessIncomingMessages(agent, [message]); oaiMessages.Count().Should().Be(1); 
oaiMessages.First().Should().BeOfType(); @@ -143,7 +144,7 @@ public void ToOpenAIChatRequestMessageTest() // user should not suppose to send tool call message message = new ToolCallMessage("test", "test", "user"); - Func action = () => agent.ToOpenAIChatRequestMessage(message).First(); + Func action = () => middleware.ProcessIncomingMessages(agent, [message]).First(); action.Should().Throw().WithMessage("ToolCallMessage is not supported when message.From is not the same with agent"); // assistant text message with multiple tool calls @@ -154,7 +155,7 @@ public void ToOpenAIChatRequestMessageTest() new ToolCall("test", "test"), ], "assistant"); - oaiMessages = agent.ToOpenAIChatRequestMessage(message); + oaiMessages = middleware.ProcessIncomingMessages(agent, [message]); oaiMessages.Count().Should().Be(1); oaiMessages.First().Should().BeOfType(); @@ -164,7 +165,7 @@ public void ToOpenAIChatRequestMessageTest() // tool call result message message = new ToolCallResultMessage("result", "test", "test", "user"); - oaiMessages = agent.ToOpenAIChatRequestMessage(message); + oaiMessages = middleware.ProcessIncomingMessages(agent, [message]); oaiMessages.Count().Should().Be(1); oaiMessages.First().Should().BeOfType(); @@ -179,7 +180,7 @@ public void ToOpenAIChatRequestMessageTest() new ToolCall("result", "test", "test"), ], "user"); - oaiMessages = agent.ToOpenAIChatRequestMessage(message); + oaiMessages = middleware.ProcessIncomingMessages(agent, [message]); oaiMessages.Count().Should().Be(2); oaiMessages.First().Should().BeOfType(); @@ -195,7 +196,7 @@ public void ToOpenAIChatRequestMessageTest() message1: new ToolCallMessage("test", "test", "assistant"), message2: new ToolCallResultMessage("result", "test", "test", "assistant"), "assistant"); - oaiMessages = agent.ToOpenAIChatRequestMessage(message); + oaiMessages = middleware.ProcessIncomingMessages(agent, [message]); oaiMessages.Count().Should().Be(2); oaiMessages.First().Should().BeOfType(); @@ -213,7 +214,7 @@ public 
void ToOpenAIChatRequestMessageTest() message1: new ToolCallMessage("test", "test", "user"), message2: new ToolCallResultMessage("result", "test", "test", "user"), "user"); - oaiMessages = agent.ToOpenAIChatRequestMessage(message); + oaiMessages = middleware.ProcessIncomingMessages(agent, [message]); oaiMessages.Count().Should().Be(1); oaiMessages.First().Should().BeOfType(); @@ -236,14 +237,14 @@ public void ToOpenAIChatRequestMessageTest() new ToolCall("result", "test", "test"), ], from: "user"), "user"); - oaiMessages = agent.ToOpenAIChatRequestMessage(message); + oaiMessages = middleware.ProcessIncomingMessages(agent, [message]); oaiMessages.Count().Should().Be(2); oaiMessages.First().Should().BeOfType(); oaiMessages.Last().Should().BeOfType(); // system message message = new TextMessage(Role.System, "You are a helpful AI assistant"); - oaiMessages = agent.ToOpenAIChatRequestMessage(message); + oaiMessages = middleware.ProcessIncomingMessages(agent, [message]); oaiMessages.Count().Should().Be(1); oaiMessages.First().Should().BeOfType(); } @@ -252,7 +253,7 @@ public void ToOpenAIChatRequestMessageTest() public void ToOpenAIChatRequestMessageShortCircuitTest() { var agent = new EchoAgent("assistant"); - + var middleware = new OpenAIMessageConnector(); ChatRequestMessage[] messages = [ new ChatRequestUserMessage("Hello"), @@ -265,7 +266,7 @@ public void ToOpenAIChatRequestMessageShortCircuitTest() foreach (var oaiMessage in messages) { IMessage message = new MessageEnvelope(oaiMessage); - var oaiMessages = agent.ToOpenAIChatRequestMessage(message); + var oaiMessages = middleware.ProcessIncomingMessages(agent, [message]); oaiMessages.Count().Should().Be(1); oaiMessages.First().Should().Be(oaiMessage); } diff --git a/dotnet/test/AutoGen.Tests/SingleAgentTest.cs b/dotnet/test/AutoGen.Tests/SingleAgentTest.cs index 0c8f00f00e3..f0a9d9a68a5 100644 --- a/dotnet/test/AutoGen.Tests/SingleAgentTest.cs +++ b/dotnet/test/AutoGen.Tests/SingleAgentTest.cs @@ -80,7 +80,7 @@ 
public async Task GPTAgentVisionTestAsync() labelResponse.Should().BeOfType(); labelResponse.From.Should().Be(gpt3Agent.Name); labelResponse.GetContent().Should().Be("[HIGHEST_LABEL] gpt-4 (n=5) green"); - labelResponse.GetToolCalls().First().FunctionName.Should().Be(nameof(GetHighestLabel)); + labelResponse.GetToolCalls()!.First().FunctionName.Should().Be(nameof(GetHighestLabel)); } [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT")] @@ -205,27 +205,25 @@ public async Task GetHighestLabel(string labelName, string color) private async Task EchoFunctionCallTestAsync(IAgent agent) { - var message = new Message(Role.System, "You are a helpful AI assistant that call echo function"); - var helloWorld = new Message(Role.User, "echo Hello world"); + var message = new TextMessage(Role.System, "You are a helpful AI assistant that call echo function"); + var helloWorld = new TextMessage(Role.User, "echo Hello world"); - var reply = await agent.SendAsync(chatHistory: new Message[] { message, helloWorld }); + var reply = await agent.SendAsync(chatHistory: new[] { message, helloWorld }); - reply.GetRole().Should().Be(Role.Assistant); reply.From.Should().Be(agent.Name); - reply.GetToolCalls().First().FunctionName.Should().Be(nameof(EchoAsync)); + reply.GetToolCalls()!.First().FunctionName.Should().Be(nameof(EchoAsync)); } private async Task EchoFunctionCallExecutionTestAsync(IAgent agent) { - var message = new Message(Role.System, "You are a helpful AI assistant that echo whatever user says"); - var helloWorld = new Message(Role.User, "echo Hello world"); + var message = new TextMessage(Role.System, "You are a helpful AI assistant that echo whatever user says"); + var helloWorld = new TextMessage(Role.User, "echo Hello world"); - var reply = await agent.SendAsync(chatHistory: new Message[] { message, helloWorld }); + var reply = await agent.SendAsync(chatHistory: new[] { message, helloWorld }); reply.GetContent().Should().Be("[ECHO] Hello world"); - 
reply.GetRole().Should().Be(Role.Assistant); reply.From.Should().Be(agent.Name); - reply.GetToolCalls().First().FunctionName.Should().Be(nameof(EchoAsync)); + reply.GetToolCalls()!.First().FunctionName.Should().Be(nameof(EchoAsync)); } private async Task EchoFunctionCallExecutionStreamingTestAsync(IStreamingAgent agent) @@ -259,13 +257,12 @@ await foreach (var reply in replyStream) private async Task UpperCaseTest(IAgent agent) { - var message = new Message(Role.System, "You are a helpful AI assistant that convert user message to upper case"); - var uppCaseMessage = new Message(Role.User, "abcdefg"); + var message = new TextMessage(Role.System, "You are a helpful AI assistant that convert user message to upper case"); + var uppCaseMessage = new TextMessage(Role.User, "abcdefg"); - var reply = await agent.SendAsync(chatHistory: new Message[] { message, uppCaseMessage }); + var reply = await agent.SendAsync(chatHistory: new[] { message, uppCaseMessage }); reply.GetContent().Should().Be("ABCDEFG"); - reply.GetRole().Should().Be(Role.Assistant); reply.From.Should().Be(agent.Name); } diff --git a/dotnet/test/AutoGen.Tests/TwoAgentTest.cs b/dotnet/test/AutoGen.Tests/TwoAgentTest.cs index a267374efab..7102066eb3e 100644 --- a/dotnet/test/AutoGen.Tests/TwoAgentTest.cs +++ b/dotnet/test/AutoGen.Tests/TwoAgentTest.cs @@ -61,7 +61,7 @@ public async Task TwoAgentWeatherChatTestAsync() .RegisterMiddleware(async (msgs, option, agent, ct) => { var lastMessage = msgs.Last(); - if (lastMessage.GetToolCalls().FirstOrDefault()?.FunctionName != null) + if (lastMessage.GetToolCalls()!.First()!.FunctionName != null) { return await agent.GenerateReplyAsync(msgs, option, ct); } From fb36566224218815ff72abd076ef7ce1496beae7 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Fri, 1 Mar 2024 00:12:41 -0800 Subject: [PATCH 13/27] fix test --- .../Example03_Agent_FunctionCall.cs | 16 ++- ...Example04_Dynamic_GroupChat_Coding_Task.cs | 2 +- .../Example05_Dalle_And_GPT4V.cs | 44 +++----- 
...7_Dynamic_GroupChat_Calculate_Fibonacci.cs | 5 +- dotnet/sample/AutoGen.BasicSamples/Program.cs | 4 +- .../Core/Extension/MessageExtension.cs | 100 +++++++++++++++++- .../Core/Middleware/FunctionCallMiddleware.cs | 20 +++- dotnet/src/AutoGen/OpenAI/GPTAgent.cs | 11 +- .../Middleware/OpenAIMessageConnector.cs | 51 ++++++++- dotnet/test/AutoGen.Tests/SingleAgentTest.cs | 2 +- dotnet/test/AutoGen.Tests/TwoAgentTest.cs | 2 +- 11 files changed, 193 insertions(+), 64 deletions(-) diff --git a/dotnet/sample/AutoGen.BasicSamples/Example03_Agent_FunctionCall.cs b/dotnet/sample/AutoGen.BasicSamples/Example03_Agent_FunctionCall.cs index 74ad84d7c92..56aa6bd624d 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Example03_Agent_FunctionCall.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Example03_Agent_FunctionCall.cs @@ -76,19 +76,15 @@ public async Task RunAsync() // talk to the assistant agent var upperCase = await agent.SendAsync("convert to upper case: hello world"); - upperCase.Should().BeOfType(); - var upperCaseResult = (ToolCallResultMessage)upperCase; - upperCaseResult.ToolCalls.First().Result?.Should().Be("HELLO WORLD"); - upperCaseResult.ToolCalls.Count().Should().Be(1); + upperCase.Should().BeOfType>(); + upperCase.GetContent()?.Should().Be("HELLO WORLD"); var concatString = await agent.SendAsync("concatenate strings: a, b, c, d, e"); - concatString.Should().BeOfType(); - var concatStringResult = (ToolCallResultMessage)concatString; - concatStringResult.ToolCalls.First().Result?.Should().Be("a b c d e"); + concatString.Should().BeOfType>(); + concatString.GetContent()?.Should().Be("a b c d e"); var calculateTax = await agent.SendAsync("calculate tax: 100, 0.1"); - calculateTax.Should().BeOfType(); - var calculateTaxResult = (ToolCallResultMessage)calculateTax; - calculateTaxResult.ToolCalls.First().Result?.Should().Be("tax is 10"); + calculateTax.Should().BeOfType>(); + calculateTax.GetContent().Should().Be("tax is 10"); } } diff --git 
a/dotnet/sample/AutoGen.BasicSamples/Example04_Dynamic_GroupChat_Coding_Task.cs b/dotnet/sample/AutoGen.BasicSamples/Example04_Dynamic_GroupChat_Coding_Task.cs index 152bb042346..7539c8bcf3d 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Example04_Dynamic_GroupChat_Coding_Task.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Example04_Dynamic_GroupChat_Coding_Task.cs @@ -47,7 +47,7 @@ public static async Task RunAsync() ConfigList = gptConfig, }); - var userProxy = new UserProxyAgent(name: "user", defaultReply: GroupChatExtension.TERMINATE) + var userProxy = new UserProxyAgent(name: "user", defaultReply: GroupChatExtension.TERMINATE, humanInputMode: HumanInputMode.NEVER) .RegisterPrintFormatMessageHook(); // Create admin agent diff --git a/dotnet/sample/AutoGen.BasicSamples/Example05_Dalle_And_GPT4V.cs b/dotnet/sample/AutoGen.BasicSamples/Example05_Dalle_And_GPT4V.cs index 50525b1e772..4cd1fbdf117 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Example05_Dalle_And_GPT4V.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Example05_Dalle_And_GPT4V.cs @@ -81,7 +81,7 @@ public static async Task RunAsync() { { nameof(GenerateImage), instance.GenerateImageWrapper }, }) - .RegisterReply(async (msgs, ct) => + .RegisterMiddleware(async (msgs, option, agent, ct) => { // if last message contains [TERMINATE], then find the last image url and terminate the conversation if (msgs.Last().GetContent()?.Contains("TERMINATE") is true) @@ -102,7 +102,19 @@ public static async Task RunAsync() }; } - return null; + var reply = await agent.GenerateReplyAsync(msgs, option, ct); + + if (reply.GetContent() is string content && content.Contains("IMAGE_GENERATION")) + { + var imageUrl = content.Split("\n").Last(); + var imageMessage = new ImageMessage(Role.Assistant, imageUrl, from: reply.From); + + return imageMessage; + } + else + { + return reply; + } }) .RegisterPrintFormatMessageHook(); @@ -121,34 +133,8 @@ public static async Task RunAsync() { Temperature = 0, ConfigList = gpt4vConfig, - 
}).RegisterReply(async (msgs, ct) => - { - // if no image is generated, then ask DALL-E agent to generate image - if (msgs.Last() is not ImageMessage) - { - return new TextMessage(Role.Assistant, "Hey dalle, please generate image", from: "gpt4v"); - } - - return null; }) - .RegisterPreProcess(async (msgs, ct) => - { - // add image url to message metadata so it can be recognized by GPT-4V - return msgs.Select(msg => - { - if (msg.GetContent() is string content && content.Contains("IMAGE_GENERATION")) - { - var imageUrl = content.Split("\n").Last(); - var imageMessage = new ImageMessage(Role.Assistant, imageUrl, from: msg.From); - - return imageMessage; - } - else - { - return msg; - } - }); - }).RegisterPrintFormatMessageHook(); + .RegisterPrintFormatMessageHook(); IEnumerable conversation = new List() { diff --git a/dotnet/sample/AutoGen.BasicSamples/Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs b/dotnet/sample/AutoGen.BasicSamples/Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs index dfe8d201f60..86b3489f731 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs @@ -85,9 +85,10 @@ public static async Task RunAsync() var reply = await innerAgent.GenerateReplyAsync(msgs, option, ct); while (maxRetry-- > 0) { - if (reply is ToolCallResultMessage toolResultMessage && toolResultMessage.ToolCalls is { Count: 1 } && toolResultMessage.ToolCalls[0].FunctionName == nameof(ReviewCodeBlock)) + if (reply.GetToolCalls() is var toolCalls && toolCalls.Count() == 1 && toolCalls[0].FunctionName == nameof(ReviewCodeBlock)) { - var reviewResultObj = JsonSerializer.Deserialize(toolResultMessage.ToolCalls[0].Result); + var toolCallResult = reply.GetContent(); + var reviewResultObj = JsonSerializer.Deserialize(toolCallResult); var reviews = new List(); if (reviewResultObj.HasMultipleCodeBlocks) { diff --git 
a/dotnet/sample/AutoGen.BasicSamples/Program.cs b/dotnet/sample/AutoGen.BasicSamples/Program.cs index 665655591ee..058cd5fa044 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Program.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Program.cs @@ -1,6 +1,4 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Program.cs -using AutoGen.BasicSample; - -await Example10_SemanticKernel.RunAsync(); +await Example05_Dalle_And_GPT4V.RunAsync(); diff --git a/dotnet/src/AutoGen/Core/Extension/MessageExtension.cs b/dotnet/src/AutoGen/Core/Extension/MessageExtension.cs index 35a0c1987da..e529571a503 100644 --- a/dotnet/src/AutoGen/Core/Extension/MessageExtension.cs +++ b/dotnet/src/AutoGen/Core/Extension/MessageExtension.cs @@ -16,10 +16,100 @@ public static string FormatMessage(this IMessage message) return message switch { Message msg => msg.FormatMessage(), + TextMessage textMessage => textMessage.FormatMessage(), + ImageMessage imageMessage => imageMessage.FormatMessage(), + ToolCallMessage toolCallMessage => toolCallMessage.FormatMessage(), + ToolCallResultMessage toolCallResultMessage => toolCallResultMessage.FormatMessage(), + AggregateMessage aggregateMessage => aggregateMessage.FormatMessage(), _ => message.ToString(), }; } + public static string FormatMessage(this TextMessage message) + { + var sb = new StringBuilder(); + // write from + sb.AppendLine($"TextMessage from {message.From}"); + // write a seperator + sb.AppendLine(separator); + sb.AppendLine(message.Content); + // write a seperator + sb.AppendLine(separator); + + return sb.ToString(); + } + + public static string FormatMessage(this ImageMessage message) + { + var sb = new StringBuilder(); + // write from + sb.AppendLine($"ImageMessage from {message.From}"); + // write a seperator + sb.AppendLine(separator); + sb.AppendLine($"Image: {message.Url}"); + // write a seperator + sb.AppendLine(separator); + + return sb.ToString(); + } + + public static string FormatMessage(this ToolCallMessage message) + 
{ + var sb = new StringBuilder(); + // write from + sb.AppendLine($"ToolCallMessage from {message.From}"); + + // write a seperator + sb.AppendLine(separator); + + foreach (var toolCall in message.ToolCalls) + { + sb.AppendLine($"- {toolCall.FunctionName}: {toolCall.FunctionArguments}"); + } + + sb.AppendLine(separator); + + return sb.ToString(); + } + + public static string FormatMessage(this ToolCallResultMessage message) + { + var sb = new StringBuilder(); + // write from + sb.AppendLine($"ToolCallResultMessage from {message.From}"); + + // write a seperator + sb.AppendLine(separator); + + foreach (var toolCall in message.ToolCalls) + { + sb.AppendLine($"- {toolCall.FunctionName}: {toolCall.Result}"); + } + + sb.AppendLine(separator); + + return sb.ToString(); + } + + public static string FormatMessage(this AggregateMessage message) + { + var sb = new StringBuilder(); + // write from + sb.AppendLine($"AggregateMessage from {message.From}"); + + // write a seperator + sb.AppendLine(separator); + + sb.AppendLine("ToolCallMessage:"); + sb.AppendLine(message.Message1.FormatMessage()); + + sb.AppendLine("ToolCallResultMessage:"); + sb.AppendLine(message.Message2.FormatMessage()); + + sb.AppendLine(separator); + + return sb.ToString(); + } public static string FormatMessage(this Message message) { var sb = new StringBuilder(); @@ -68,6 +158,8 @@ public static bool IsSystemMessage(this IMessage message) /// Get the content from the message /// if the message is a or , return the content /// if the message is a and only contains one function call, return the result of that function call + /// if the message is a where TMessage1 is and TMessage2 is and the second message only contains one function call, return the result of that function call + /// for all other situation, return null. /// /// public static string? 
GetContent(this IMessage message) @@ -76,7 +168,8 @@ public static bool IsSystemMessage(this IMessage message) { TextMessage textMessage => textMessage.Content, Message msg => msg.Content, - ToolCallResultMessage toolCallResultMessage => toolCallResultMessage.GetToolCalls().Count() == 1 ? toolCallResultMessage.GetToolCalls().First().Result : null, + ToolCallResultMessage toolCallResultMessage => toolCallResultMessage.ToolCalls.Count == 1 ? toolCallResultMessage.ToolCalls.First().Result : null, + AggregateMessage aggregateMessage => aggregateMessage.Message2.ToolCalls.Count == 1 ? aggregateMessage.Message2.ToolCalls.First().Result : null, _ => null, }; } @@ -98,6 +191,9 @@ public static bool IsSystemMessage(this IMessage message) /// /// Return the tool calls from the message if it's available. + /// if the message is a , return its tool calls + /// if the message is a and the function name and function arguments are available, return a list of tool call with one item + /// if the message is a where TMessage1 is and TMessage2 is , return the tool calls from the first message /// /// /// @@ -106,11 +202,11 @@ public static bool IsSystemMessage(this IMessage message) return message switch { ToolCallMessage toolCallMessage => toolCallMessage.ToolCalls, - ToolCallResultMessage toolCallResultMessage => toolCallResultMessage.ToolCalls, Message msg => msg.FunctionName is not null && msg.FunctionArguments is not null ? msg.Content is not null ? 
new List { new ToolCall(msg.FunctionName, msg.FunctionArguments, result: msg.Content) } : new List { new ToolCall(msg.FunctionName, msg.FunctionArguments) } : null, + AggregateMessage aggregateMessage => aggregateMessage.Message1.ToolCalls, _ => null, }; } diff --git a/dotnet/src/AutoGen/Core/Middleware/FunctionCallMiddleware.cs b/dotnet/src/AutoGen/Core/Middleware/FunctionCallMiddleware.cs index 9445f9180a2..8792edd68b2 100644 --- a/dotnet/src/AutoGen/Core/Middleware/FunctionCallMiddleware.cs +++ b/dotnet/src/AutoGen/Core/Middleware/FunctionCallMiddleware.cs @@ -12,6 +12,17 @@ namespace AutoGen.Core.Middleware; /// /// The middleware that process function call message that both send to an agent or reply from an agent. +/// If the last message is and the tool calls is available in this middleware's function map, +/// the tools from the last message will be invoked and a will be returned. In this situation, +/// the inner agent will be short-cut and won't be invoked. +/// Otherwise, the message will be sent to the inner agent. In this situation +/// if the reply from the inner agent is , +/// and the tool calls is available in this middleware's function map, the tools from the reply will be invoked, +/// and a where TMessage1 is and TMessage2 is "/> +/// will be returned. +/// +/// If the reply from the inner agent is but the tool calls is not available in this middleware's function map, +/// or the reply from the inner agent is not , the original reply from the inner agent will be returned. /// public class FunctionCallMiddleware : IMiddleware { @@ -33,9 +44,10 @@ public async Task InvokeAsync(MiddlewareContext context, IAgent agent, { // if the last message is a function call message, invoke the function and return the result instead of sending to the agent. 
var lastMessage = context.Messages.Last(); - if (lastMessage.GetToolCalls() is IList toolCalls && toolCalls.Count() == 1) + if (lastMessage is ToolCallMessage toolCallMessage) { var toolCallResult = new List(); + var toolCalls = toolCallMessage.ToolCalls; foreach (var toolCall in toolCalls) { var functionName = toolCall.FunctionName; @@ -68,8 +80,9 @@ public async Task InvokeAsync(MiddlewareContext context, IAgent agent, var reply = await agent.GenerateReplyAsync(context.Messages, options, cancellationToken); // if the reply is a function call message plus the function's name is available in function map, invoke the function and return the result instead of sending to the agent. - if (reply.GetToolCalls() is IList toolCallsReply) + if (reply is ToolCallMessage toolCallMsg) { + var toolCallsReply = toolCallMsg.ToolCalls; var toolCallResult = new List(); foreach (var toolCall in toolCallsReply) { @@ -84,7 +97,8 @@ public async Task InvokeAsync(MiddlewareContext context, IAgent agent, if (toolCallResult.Count() > 0) { - return new ToolCallResultMessage(toolCallResult, from: agent.Name); + var toolCallResultMessage = new ToolCallResultMessage(toolCallResult, from: agent.Name); + return new AggregateMessage(toolCallMsg, toolCallResultMessage, from: agent.Name); } else { diff --git a/dotnet/src/AutoGen/OpenAI/GPTAgent.cs b/dotnet/src/AutoGen/OpenAI/GPTAgent.cs index f3e2f5bf949..92f1ed360aa 100644 --- a/dotnet/src/AutoGen/OpenAI/GPTAgent.cs +++ b/dotnet/src/AutoGen/OpenAI/GPTAgent.cs @@ -91,14 +91,7 @@ public class GPTAgent : IStreamingAgent agent = agent.RegisterMiddleware(functionMapMiddleware); } - var reply = await agent.GenerateReplyAsync(messages, options, cancellationToken); - - if (reply is IMessage oaiMessage) - { - return await this.PostProcessMessage(oaiMessage.Content); - } - - throw new Exception("Invalid message type"); + return await agent.GenerateReplyAsync(messages, options, cancellationToken); } public async Task> GenerateStreamingReplyAsync( @@ -344,7 
+337,7 @@ private ChatCompletionsOptions CreateChatCompletionsOptions(GenerateReplyOptions { foreach (var f in functions) { - settings.Functions.Add(f); + settings.Tools.Add(new ChatCompletionsFunctionToolDefinition(f)); } } diff --git a/dotnet/src/AutoGen/OpenAI/Middleware/OpenAIMessageConnector.cs b/dotnet/src/AutoGen/OpenAI/Middleware/OpenAIMessageConnector.cs index eb13210ab51..25b9eb99b76 100644 --- a/dotnet/src/AutoGen/OpenAI/Middleware/OpenAIMessageConnector.cs +++ b/dotnet/src/AutoGen/OpenAI/Middleware/OpenAIMessageConnector.cs @@ -27,12 +27,14 @@ public class OpenAIMessageConnector : IMiddleware, IStreamingMiddleware { public string? Name => nameof(OpenAIMessageConnector); - public Task InvokeAsync(MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default) + public async Task InvokeAsync(MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default) { var chatMessages = ProcessIncomingMessages(agent, context.Messages) .Select(m => new MessageEnvelope(m)); - return agent.GenerateReplyAsync(chatMessages, context.Options, cancellationToken); + var reply = await agent.GenerateReplyAsync(chatMessages, context.Options, cancellationToken); + + return PostProcessMessage(reply); } public Task> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, CancellationToken cancellationToken = default) @@ -40,6 +42,49 @@ public Task> InvokeAsync(MiddlewareContext context, I throw new NotImplementedException(); } + public IMessage PostProcessMessage(IMessage message) + { + return message switch + { + TextMessage => message, + ImageMessage => message, + MultiModalMessage => message, + ToolCallMessage => message, + ToolCallResultMessage => message, + Message => message, + AggregateMessage => message, + IMessage m => PostProcessMessage(m), + _ => throw new InvalidOperationException("The type of message is not supported. 
Must be one of TextMessage, ImageMessage, MultiModalMessage, ToolCallMessage, ToolCallResultMessage, Message, IMessage, AggregateMessage"), + }; + } + + private IMessage PostProcessMessage(IMessage message) + { + var chatResponseMessage = message.Content; + if (chatResponseMessage.Content is string content) + { + return new TextMessage(Role.Assistant, content, message.From); + } + + if (chatResponseMessage.FunctionCall is FunctionCall functionCall) + { + return new ToolCallMessage(functionCall.Name, functionCall.Arguments, message.From); + } + + if (chatResponseMessage.ToolCalls.Where(tc => tc is ChatCompletionsFunctionToolCall).Any()) + { + var functionToolCalls = chatResponseMessage.ToolCalls + .Where(tc => tc is ChatCompletionsFunctionToolCall) + .Select(tc => (ChatCompletionsFunctionToolCall)tc); + + var toolCalls = functionToolCalls.Select(tc => new ToolCall(tc.Name, tc.Arguments)); + + return new ToolCallMessage(toolCalls, message.From); + } + + throw new InvalidOperationException("Invalid ChatResponseMessage"); + } + public IEnumerable ProcessIncomingMessages(IAgent agent, IEnumerable messages) { return messages.SelectMany(m => @@ -160,7 +205,7 @@ private IEnumerable ProcessIncomingMessagesForSelf(Message m } else { - return new[] { new ChatRequestFunctionMessage(message.FunctionName, content) }; + return new[] { new ChatRequestToolMessage(content, message.FunctionName) }; } } else if (message.FunctionName is string functionName) diff --git a/dotnet/test/AutoGen.Tests/SingleAgentTest.cs b/dotnet/test/AutoGen.Tests/SingleAgentTest.cs index f0a9d9a68a5..b5f6f037d13 100644 --- a/dotnet/test/AutoGen.Tests/SingleAgentTest.cs +++ b/dotnet/test/AutoGen.Tests/SingleAgentTest.cs @@ -223,7 +223,7 @@ private async Task EchoFunctionCallExecutionTestAsync(IAgent agent) reply.GetContent().Should().Be("[ECHO] Hello world"); reply.From.Should().Be(agent.Name); - reply.GetToolCalls()!.First().FunctionName.Should().Be(nameof(EchoAsync)); + reply.Should().BeOfType>(); } 
private async Task EchoFunctionCallExecutionStreamingTestAsync(IStreamingAgent agent) diff --git a/dotnet/test/AutoGen.Tests/TwoAgentTest.cs b/dotnet/test/AutoGen.Tests/TwoAgentTest.cs index 7102066eb3e..1f5f5b01f9d 100644 --- a/dotnet/test/AutoGen.Tests/TwoAgentTest.cs +++ b/dotnet/test/AutoGen.Tests/TwoAgentTest.cs @@ -61,7 +61,7 @@ public async Task TwoAgentWeatherChatTestAsync() .RegisterMiddleware(async (msgs, option, agent, ct) => { var lastMessage = msgs.Last(); - if (lastMessage.GetToolCalls()!.First()!.FunctionName != null) + if (lastMessage.GetToolCalls()?.FirstOrDefault()?.FunctionName != null) { return await agent.GenerateReplyAsync(msgs, option, ct); } From bc3e1df8347ed9c1cd8c15b8a041f39425a1b890 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Fri, 1 Mar 2024 11:01:58 -0800 Subject: [PATCH 14/27] add comments --- dotnet/src/AutoGen/Core/Message/IMessage.cs | 32 ++++ dotnet/src/AutoGen/OpenAI/GPTAgent.cs | 147 ++---------------- .../src/AutoGen/OpenAI/OpenAIClientAgent.cs | 116 ++++++++++++++ dotnet/test/AutoGen.Tests/SingleAgentTest.cs | 25 +-- 4 files changed, 181 insertions(+), 139 deletions(-) create mode 100644 dotnet/src/AutoGen/OpenAI/OpenAIClientAgent.cs diff --git a/dotnet/src/AutoGen/Core/Message/IMessage.cs b/dotnet/src/AutoGen/Core/Message/IMessage.cs index 9ade6c1ab66..2d6671be3c8 100644 --- a/dotnet/src/AutoGen/Core/Message/IMessage.cs +++ b/dotnet/src/AutoGen/Core/Message/IMessage.cs @@ -1,8 +1,40 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // IMessage.cs +using AutoGen.Core.Middleware; + namespace AutoGen; +/// +/// The universal message interface for all message types in AutoGen. +/// Related PR: https://github.com/microsoft/autogen/pull/1676 +/// Built-in message types +/// +/// +/// : plain text message. +/// +/// +/// : image message. +/// +/// +/// : message type for multimodal message. The current support message items are and . +/// +/// +/// : message type for tool call. 
This message supports both single and parallel tool call. +/// +/// +/// : message type for tool call result. +/// +/// +/// : This type is used by previous version of AutoGen. And it's reserved for backward compatibility. +/// +/// +/// : an aggregate message type that contains two message types. +/// This type is useful when you want to combine two message types into one unique message type. One example is when invoking a tool call and you want to return both and . +/// One example of how this type is used in AutoGen is +/// +/// +/// public interface IMessage { string? From { get; set; } diff --git a/dotnet/src/AutoGen/OpenAI/GPTAgent.cs b/dotnet/src/AutoGen/OpenAI/GPTAgent.cs index 92f1ed360aa..54f271f8d47 100644 --- a/dotnet/src/AutoGen/OpenAI/GPTAgent.cs +++ b/dotnet/src/AutoGen/OpenAI/GPTAgent.cs @@ -12,6 +12,23 @@ namespace AutoGen.OpenAI; +/// +/// GPT agent that can be used to connect to OpenAI chat models like GPT-3.5, GPT-4, etc. +/// supports the following message types as input: +/// - +/// - +/// - +/// - +/// - +/// - +/// - where T is +/// - where TMessage1 is and TMessage2 is +/// +/// returns the following message types: +/// - +/// - +/// - where TMessage1 is and TMessage2 is +/// public class GPTAgent : IStreamingAgent { private readonly string _systemMessage; @@ -222,134 +239,4 @@ private IEnumerable ProcessMessages(IEnumerable me return openAIMessages; } - - private async Task PostProcessMessage(ChatResponseMessage oaiMessage) - { - if (this.functionMap != null && oaiMessage.FunctionCall is FunctionCall fc) - { - if (this.functionMap.TryGetValue(fc.Name, out var func)) - { - var result = await func(fc.Arguments); - return new Message(Role.Assistant, result, from: this.Name) - { - FunctionName = fc.Name, - FunctionArguments = fc.Arguments, - Value = oaiMessage, - }; - } - else - { - var errorMessage = $"Function {fc.Name} is not available. 
Available functions are: {string.Join(", ", this.functionMap.Select(f => f.Key))}"; - return new Message(Role.Assistant, errorMessage, from: this.Name) - { - FunctionName = fc.Name, - FunctionArguments = fc.Arguments, - Value = oaiMessage, - }; - } - } - else - { - if (string.IsNullOrEmpty(oaiMessage.Content) && oaiMessage.FunctionCall is null) - { - throw new Exception("OpenAI response is invalid."); - } - return new Message(Role.Assistant, oaiMessage.Content) - { - From = this.Name, - FunctionName = oaiMessage.FunctionCall?.Name, - FunctionArguments = oaiMessage.FunctionCall?.Arguments, - Value = oaiMessage, - }; - } - } - - private class OpenAIClientAgent : IStreamingAgent - { - private readonly OpenAIClient openAIClient; - private readonly string modelName; - private readonly float _temperature; - private readonly int _maxTokens = 1024; - private readonly IEnumerable? _functions; - private readonly string _systemMessage; - - public OpenAIClientAgent( - OpenAIClient openAIClient, - string name, - string systemMessage, - string modelName, - float temperature = 0.7f, - int maxTokens = 1024, - IEnumerable? functions = null) - { - this.openAIClient = openAIClient; - this.modelName = modelName; - this.Name = name; - _temperature = temperature; - _maxTokens = maxTokens; - _functions = functions; - _systemMessage = systemMessage; - } - - public string Name { get; } - - public async Task GenerateReplyAsync( - IEnumerable messages, - GenerateReplyOptions? options = null, - CancellationToken cancellationToken = default) - { - var settings = this.CreateChatCompletionsOptions(options, messages); - var reply = await this.openAIClient.GetChatCompletionsAsync(settings, cancellationToken); - - return new MessageEnvelope(reply.Value.Choices.First().Message, from: this.Name); - } - - public Task> GenerateStreamingReplyAsync( - IEnumerable messages, - GenerateReplyOptions? 
options = null, - CancellationToken cancellationToken = default) - { - throw new NotImplementedException(); - } - - private ChatCompletionsOptions CreateChatCompletionsOptions(GenerateReplyOptions? options, IEnumerable messages) - { - var oaiMessages = messages.Select(m => m switch - { - IMessage chatRequestMessage => chatRequestMessage.Content, - _ => throw new ArgumentException("Invalid message type") - }); - - // add system message if there's no system message in messages - if (!oaiMessages.Any(m => m is ChatRequestSystemMessage)) - { - oaiMessages = new[] { new ChatRequestSystemMessage(_systemMessage) }.Concat(oaiMessages); - } - - var settings = new ChatCompletionsOptions(this.modelName, oaiMessages) - { - MaxTokens = options?.MaxToken ?? _maxTokens, - Temperature = options?.Temperature ?? _temperature, - }; - - var functions = options?.Functions ?? _functions; - if (functions is not null && functions.Count() > 0) - { - foreach (var f in functions) - { - settings.Tools.Add(new ChatCompletionsFunctionToolDefinition(f)); - } - } - - if (options?.StopSequence is var sequence && sequence is { Length: > 0 }) - { - foreach (var seq in sequence) - { - settings.StopSequences.Add(seq); - } - } - - return settings; - } - } } diff --git a/dotnet/src/AutoGen/OpenAI/OpenAIClientAgent.cs b/dotnet/src/AutoGen/OpenAI/OpenAIClientAgent.cs new file mode 100644 index 00000000000..cd1d66da17d --- /dev/null +++ b/dotnet/src/AutoGen/OpenAI/OpenAIClientAgent.cs @@ -0,0 +1,116 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// OpenAIClientAgent.cs + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Azure.AI.OpenAI; + +namespace AutoGen.OpenAI; + +/// +/// OpenAI client agent. This agent is a thin wrapper around to provide a simple interface for chat completions. 
+/// To better work with other agents, it's recommended to use which supports more message types and have a better compatibility with other agents. +/// supports the following message types: +/// +/// +/// where T is : chat request message. +/// +/// +/// returns the following message types: +/// +/// +/// where T is : chat response message. +/// +/// +/// +public class OpenAIClientAgent : IStreamingAgent +{ + private readonly OpenAIClient openAIClient; + private readonly string modelName; + private readonly float _temperature; + private readonly int _maxTokens = 1024; + private readonly IEnumerable? _functions; + private readonly string _systemMessage; + + public OpenAIClientAgent( + OpenAIClient openAIClient, + string name, + string systemMessage, + string modelName, + float temperature = 0.7f, + int maxTokens = 1024, + IEnumerable? functions = null) + { + this.openAIClient = openAIClient; + this.modelName = modelName; + this.Name = name; + _temperature = temperature; + _maxTokens = maxTokens; + _functions = functions; + _systemMessage = systemMessage; + } + + public string Name { get; } + + public async Task GenerateReplyAsync( + IEnumerable messages, + GenerateReplyOptions? options = null, + CancellationToken cancellationToken = default) + { + var settings = this.CreateChatCompletionsOptions(options, messages); + var reply = await this.openAIClient.GetChatCompletionsAsync(settings, cancellationToken); + + return new MessageEnvelope(reply.Value.Choices.First().Message, from: this.Name); + } + + public Task> GenerateStreamingReplyAsync( + IEnumerable messages, + GenerateReplyOptions? options = null, + CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + private ChatCompletionsOptions CreateChatCompletionsOptions(GenerateReplyOptions? 
options, IEnumerable messages) + { + var oaiMessages = messages.Select(m => m switch + { + IMessage chatRequestMessage => chatRequestMessage.Content, + _ => throw new ArgumentException("Invalid message type") + }); + + // add system message if there's no system message in messages + if (!oaiMessages.Any(m => m is ChatRequestSystemMessage)) + { + oaiMessages = new[] { new ChatRequestSystemMessage(_systemMessage) }.Concat(oaiMessages); + } + + var settings = new ChatCompletionsOptions(this.modelName, oaiMessages) + { + MaxTokens = options?.MaxToken ?? _maxTokens, + Temperature = options?.Temperature ?? _temperature, + }; + + var functions = options?.Functions ?? _functions; + if (functions is not null && functions.Count() > 0) + { + foreach (var f in functions) + { + settings.Tools.Add(new ChatCompletionsFunctionToolDefinition(f)); + } + } + + if (options?.StopSequence is var sequence && sequence is { Length: > 0 }) + { + foreach (var seq in sequence) + { + settings.StopSequences.Add(seq); + } + } + + return settings; + } +} diff --git a/dotnet/test/AutoGen.Tests/SingleAgentTest.cs b/dotnet/test/AutoGen.Tests/SingleAgentTest.cs index b5f6f037d13..2e2569fa840 100644 --- a/dotnet/test/AutoGen.Tests/SingleAgentTest.cs +++ b/dotnet/test/AutoGen.Tests/SingleAgentTest.cs @@ -71,16 +71,23 @@ public async Task GPTAgentVisionTestAsync() var oaiMessage = new ChatRequestUserMessage( new ChatMessageTextContentItem("which label has the highest inference cost"), new ChatMessageImageContentItem(new Uri(@"https://raw.githubusercontent.com/microsoft/autogen/main/website/blog/2023-04-21-LLM-tuning-math/img/level2algebra.png"))); + var multiModalMessage = new MultiModalMessage( + [ + new TextMessage(Role.User, "which label has the highest inference cost", from: "user"), + new ImageMessage(Role.User, @"https://raw.githubusercontent.com/microsoft/autogen/main/website/blog/2023-04-21-LLM-tuning-math/img/level2algebra.png", from: "user"), + ], + from: "user"); + + foreach (var message in 
new IMessage[] { new MessageEnvelope(oaiMessage), multiModalMessage }) + { + var response = await visionAgent.SendAsync(message); + response.From.Should().Be(visionAgent.Name); - var message = oaiMessage.ToMessage(); - var response = await visionAgent.SendAsync(message); - response.From.Should().Be(visionAgent.Name); - - var labelResponse = await gpt3Agent.SendAsync(response); - labelResponse.Should().BeOfType(); - labelResponse.From.Should().Be(gpt3Agent.Name); - labelResponse.GetContent().Should().Be("[HIGHEST_LABEL] gpt-4 (n=5) green"); - labelResponse.GetToolCalls()!.First().FunctionName.Should().Be(nameof(GetHighestLabel)); + var labelResponse = await gpt3Agent.SendAsync(response); + labelResponse.From.Should().Be(gpt3Agent.Name); + labelResponse.GetContent().Should().Be("[HIGHEST_LABEL] gpt-4 (n=5) green"); + labelResponse.GetToolCalls()!.First().FunctionName.Should().Be(nameof(GetHighestLabel)); + } } [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT")] From 0d5fdac677cf1931f50efc1f9f060bdb5b0cb526 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Fri, 1 Mar 2024 11:47:12 -0800 Subject: [PATCH 15/27] use FunctionContract to replace FunctionDefinition --- .../CodeSnippet/CreateAnAgent.cs | 8 ++++---- .../CodeSnippet/FunctionCallCodeSnippet.cs | 13 ++++++------- .../Example03_Agent_FunctionCall.cs | 8 ++++---- .../Example05_Dalle_And_GPT4V.cs | 4 ++-- ...ample07_Dynamic_GroupChat_Calculate_Fibonacci.cs | 4 ++-- dotnet/src/AutoGen/Core/Agent/ConversableAgent.cs | 8 +++++--- dotnet/src/AutoGen/Core/Agent/IAgent.cs | 3 +-- dotnet/src/AutoGen/Core/ConversableAgentConfig.cs | 3 +-- .../Core/Middleware/FunctionCallMiddleware.cs | 6 +++--- dotnet/src/AutoGen/OpenAI/GPTAgent.cs | 4 +++- .../OpenAI/Middleware/OpenAIMessageConnector.cs | 9 ++++++++- dotnet/src/AutoGen/OpenAI/OpenAIClientAgent.cs | 4 +++- dotnet/test/AutoGen.Tests/MathClassTest.cs | 10 +++++----- dotnet/test/AutoGen.Tests/MiddlewareTest.cs | 2 +- dotnet/test/AutoGen.Tests/SingleAgentTest.cs | 8 
++++---- dotnet/test/AutoGen.Tests/TwoAgentTest.cs | 4 ++-- 16 files changed, 54 insertions(+), 44 deletions(-) diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/CreateAnAgent.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/CreateAnAgent.cs index ae78cf91523..e64206e56e9 100644 --- a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/CreateAnAgent.cs +++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/CreateAnAgent.cs @@ -86,9 +86,9 @@ public async Task CodeSnippet4() { llmConfig }, - FunctionDefinitions = new[] + FunctionContracts = new[] { - this.UpperCaseFunction, // The FunctionDefinition object for the UpperCase function + this.UpperCaseFunctionContract, // The FunctionDefinition object for the UpperCase function }, }); @@ -121,9 +121,9 @@ public async Task CodeSnippet5() { llmConfig }, - FunctionDefinitions = new[] + FunctionContracts = new[] { - this.UpperCaseFunction, // The FunctionDefinition object for the UpperCase function + this.UpperCaseFunctionContract, // The FunctionDefinition object for the UpperCase function }, }, functionMap: new Dictionary>> diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/FunctionCallCodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/FunctionCallCodeSnippet.cs index 6d37c574a4f..31196a6735a 100644 --- a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/FunctionCallCodeSnippet.cs +++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/FunctionCallCodeSnippet.cs @@ -3,7 +3,6 @@ using AutoGen; using AutoGen.OpenAI; -using AutoGen.OpenAI.Extension; using FluentAssertions; public partial class FunctionCallCodeSnippet @@ -30,9 +29,9 @@ public async Task CodeSnippet4() { llmConfig }, - FunctionDefinitions = new[] + FunctionContracts = new[] { - function.WeatherReportFunctionContract.ToOpenAIFunctionDefinition(), // The FunctionDefinition object for the weather report function + function.WeatherReportFunctionContract, }, }); @@ -67,9 +66,9 @@ public async Task CodeSnippet6() { llmConfig }, - 
FunctionDefinitions = new[] + FunctionContracts = new[] { - function.WeatherReportFunctionContract.ToOpenAIFunctionDefinition(), // The FunctionDefinition object for the weather report function + function.WeatherReportFunctionContract, }, }, functionMap: new Dictionary>> @@ -100,9 +99,9 @@ public async Task TwoAgentWeatherChatTestAsync() llmConfig: new ConversableAgentConfig { ConfigList = new[] { config }, - FunctionDefinitions = new[] + FunctionContracts = new[] { - function.WeatherReportFunctionContract.ToOpenAIFunctionDefinition(), + function.WeatherReportFunctionContract, }, }); diff --git a/dotnet/sample/AutoGen.BasicSamples/Example03_Agent_FunctionCall.cs b/dotnet/sample/AutoGen.BasicSamples/Example03_Agent_FunctionCall.cs index 56aa6bd624d..f26bc0fa139 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Example03_Agent_FunctionCall.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Example03_Agent_FunctionCall.cs @@ -54,11 +54,11 @@ public async Task RunAsync() { Temperature = 0, ConfigList = llmConfig, - FunctionDefinitions = new[] + FunctionContracts = new[] { - ConcatStringFunction, - UpperCaseFunction, - CalculateTaxFunction, + ConcatStringFunctionContract, + UpperCaseFunctionContract, + CalculateTaxFunctionContract, }, }; diff --git a/dotnet/sample/AutoGen.BasicSamples/Example05_Dalle_And_GPT4V.cs b/dotnet/sample/AutoGen.BasicSamples/Example05_Dalle_And_GPT4V.cs index 4cd1fbdf117..ddc556d0fe7 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Example05_Dalle_And_GPT4V.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Example05_Dalle_And_GPT4V.cs @@ -72,9 +72,9 @@ public static async Task RunAsync() { Temperature = 0, ConfigList = gpt35Config, - FunctionDefinitions = new[] + FunctionContracts = new[] { - instance.GenerateImageFunction, + instance.GenerateImageFunctionContract, }, }, functionMap: new Dictionary>> diff --git a/dotnet/sample/AutoGen.BasicSamples/Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs 
b/dotnet/sample/AutoGen.BasicSamples/Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs index 86b3489f731..c46cd20584e 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs @@ -70,9 +70,9 @@ public static async Task RunAsync() { Temperature = 0, ConfigList = gpt3Config, - FunctionDefinitions = new[] + FunctionContracts = new[] { - functions.ReviewCodeBlockFunction, + functions.ReviewCodeBlockFunctionContract, }, }, functionMap: new Dictionary>>() diff --git a/dotnet/src/AutoGen/Core/Agent/ConversableAgent.cs b/dotnet/src/AutoGen/Core/Agent/ConversableAgent.cs index 838fef2e858..ce1ef79993a 100644 --- a/dotnet/src/AutoGen/Core/Agent/ConversableAgent.cs +++ b/dotnet/src/AutoGen/Core/Agent/ConversableAgent.cs @@ -36,6 +36,7 @@ public class ConversableAgent : IAgent private readonly HumanInputMode humanInputMode; private readonly IDictionary>>? functionMap; private readonly string systemMessage; + private readonly IEnumerable? functions; public ConversableAgent( string name, @@ -71,6 +72,7 @@ public class ConversableAgent : IAgent this.IsTermination = isTermination; this.systemMessage = systemMessage; this.innerAgent = llmConfig?.ConfigList != null ? this.CreateInnerAgentFromConfigList(llmConfig) : null; + this.functions = llmConfig?.FunctionContracts; } private IAgent? CreateInnerAgentFromConfigList(ConversableAgentConfig config) @@ -82,8 +84,8 @@ public class ConversableAgent : IAgent { null => llmConfig switch { - AzureOpenAIConfig azureConfig => new GPTAgent(this.Name!, this.systemMessage, azureConfig, temperature: config.Temperature ?? 0, functions: config.FunctionDefinitions), - OpenAIConfig openAIConfig => new GPTAgent(this.Name!, this.systemMessage, openAIConfig, temperature: config.Temperature ?? 
0, functions: config.FunctionDefinitions), + AzureOpenAIConfig azureConfig => new GPTAgent(this.Name!, this.systemMessage, azureConfig, temperature: config.Temperature ?? 0), + OpenAIConfig openAIConfig => new GPTAgent(this.Name!, this.systemMessage, openAIConfig, temperature: config.Temperature ?? 0), _ => throw new ArgumentException($"Unsupported config type {llmConfig.GetType()}"), }, IAgent innerAgent => innerAgent.RegisterReply(async (messages, cancellationToken) => @@ -147,7 +149,7 @@ public class ConversableAgent : IAgent agent.Use(humanInputMiddleware); // process function call - var functionCallMiddleware = new FunctionCallMiddleware(functionMap: this.functionMap); + var functionCallMiddleware = new FunctionCallMiddleware(functions: this.functions, functionMap: this.functionMap); agent.Use(functionCallMiddleware); return await agent.GenerateReplyAsync(messages, overrideOptions, cancellationToken); diff --git a/dotnet/src/AutoGen/Core/Agent/IAgent.cs b/dotnet/src/AutoGen/Core/Agent/IAgent.cs index 2450bee6f34..53a0493d47b 100644 --- a/dotnet/src/AutoGen/Core/Agent/IAgent.cs +++ b/dotnet/src/AutoGen/Core/Agent/IAgent.cs @@ -5,7 +5,6 @@ using System.Linq; using System.Threading; using System.Threading.Tasks; -using Azure.AI.OpenAI; namespace AutoGen; @@ -48,5 +47,5 @@ public GenerateReplyOptions(GenerateReplyOptions other) public string[]? StopSequence { get; set; } - public FunctionDefinition[]? Functions { get; set; } + public FunctionContract[]? Functions { get; set; } } diff --git a/dotnet/src/AutoGen/Core/ConversableAgentConfig.cs b/dotnet/src/AutoGen/Core/ConversableAgentConfig.cs index 6bf78ec5452..50a83ba8620 100644 --- a/dotnet/src/AutoGen/Core/ConversableAgentConfig.cs +++ b/dotnet/src/AutoGen/Core/ConversableAgentConfig.cs @@ -2,13 +2,12 @@ // ConversableAgentConfig.cs using System.Collections.Generic; -using Azure.AI.OpenAI; namespace AutoGen; public class ConversableAgentConfig { - public IEnumerable? 
FunctionDefinitions { get; set; } + public IEnumerable? FunctionContracts { get; set; } public IEnumerable? ConfigList { get; set; } diff --git a/dotnet/src/AutoGen/Core/Middleware/FunctionCallMiddleware.cs b/dotnet/src/AutoGen/Core/Middleware/FunctionCallMiddleware.cs index 8792edd68b2..b7190611bf4 100644 --- a/dotnet/src/AutoGen/Core/Middleware/FunctionCallMiddleware.cs +++ b/dotnet/src/AutoGen/Core/Middleware/FunctionCallMiddleware.cs @@ -6,7 +6,6 @@ using System.Linq; using System.Threading; using System.Threading.Tasks; -using Azure.AI.OpenAI; namespace AutoGen.Core.Middleware; @@ -26,10 +25,11 @@ namespace AutoGen.Core.Middleware; /// public class FunctionCallMiddleware : IMiddleware { - private readonly IEnumerable? functions; + private readonly IEnumerable? functions; private readonly IDictionary>>? functionMap; + public FunctionCallMiddleware( - IEnumerable? functions = null, + IEnumerable? functions = null, IDictionary>>? functionMap = null, string? name = null) { diff --git a/dotnet/src/AutoGen/OpenAI/GPTAgent.cs b/dotnet/src/AutoGen/OpenAI/GPTAgent.cs index 54f271f8d47..6bb6cde558d 100644 --- a/dotnet/src/AutoGen/OpenAI/GPTAgent.cs +++ b/dotnet/src/AutoGen/OpenAI/GPTAgent.cs @@ -7,6 +7,7 @@ using System.Threading; using System.Threading.Tasks; using AutoGen.Core.Middleware; +using AutoGen.OpenAI.Extension; using AutoGen.OpenAI.Middleware; using Azure.AI.OpenAI; @@ -203,7 +204,8 @@ private ChatCompletionsOptions CreateChatCompletionsOptions(GenerateReplyOptions Temperature = options?.Temperature ?? _temperature, }; - var functions = options?.Functions ?? _functions; + var openAIFunctions = options?.Functions?.Select(f => f.ToOpenAIFunctionDefinition()); + var functions = openAIFunctions ?? 
_functions; if (functions is not null && functions.Count() > 0) { foreach (var f in functions) diff --git a/dotnet/src/AutoGen/OpenAI/Middleware/OpenAIMessageConnector.cs b/dotnet/src/AutoGen/OpenAI/Middleware/OpenAIMessageConnector.cs index 25b9eb99b76..887dd7fc6ea 100644 --- a/dotnet/src/AutoGen/OpenAI/Middleware/OpenAIMessageConnector.cs +++ b/dotnet/src/AutoGen/OpenAI/Middleware/OpenAIMessageConnector.cs @@ -25,6 +25,13 @@ namespace AutoGen.OpenAI.Middleware; /// public class OpenAIMessageConnector : IMiddleware, IStreamingMiddleware { + private bool strictMode = false; + + public OpenAIMessageConnector(bool strictMode = false) + { + this.strictMode = strictMode; + } + public string? Name => nameof(OpenAIMessageConnector); public async Task InvokeAsync(MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default) @@ -283,7 +290,7 @@ private IEnumerable ProcessIncomingMessagesForOther(MultiMod return new[] { new ChatRequestUserMessage(items) }; } - private IEnumerable ProcessIncomingMessagesForOther(ToolCallMessage _) + private IEnumerable ProcessIncomingMessagesForOther(ToolCallMessage msg) { throw new ArgumentException("ToolCallMessage is not supported when message.From is not the same with agent"); } diff --git a/dotnet/src/AutoGen/OpenAI/OpenAIClientAgent.cs b/dotnet/src/AutoGen/OpenAI/OpenAIClientAgent.cs index cd1d66da17d..94148243642 100644 --- a/dotnet/src/AutoGen/OpenAI/OpenAIClientAgent.cs +++ b/dotnet/src/AutoGen/OpenAI/OpenAIClientAgent.cs @@ -6,6 +6,7 @@ using System.Linq; using System.Threading; using System.Threading.Tasks; +using AutoGen.OpenAI.Extension; using Azure.AI.OpenAI; namespace AutoGen.OpenAI; @@ -94,7 +95,8 @@ private ChatCompletionsOptions CreateChatCompletionsOptions(GenerateReplyOptions Temperature = options?.Temperature ?? _temperature, }; - var functions = options?.Functions ?? 
_functions; + var openAIFunctionDefinitions = options?.Functions?.Select(f => f.ToOpenAIFunctionDefinition()); + var functions = openAIFunctionDefinitions ?? _functions; if (functions is not null && functions.Count() > 0) { foreach (var f in functions) diff --git a/dotnet/test/AutoGen.Tests/MathClassTest.cs b/dotnet/test/AutoGen.Tests/MathClassTest.cs index 367d31668ce..ad8f7f4e575 100644 --- a/dotnet/test/AutoGen.Tests/MathClassTest.cs +++ b/dotnet/test/AutoGen.Tests/MathClassTest.cs @@ -122,10 +122,10 @@ private async Task CreateTeacherAssistantAgentAsync() { config, }, - FunctionDefinitions = new[] + FunctionContracts = new[] { - this.CreateMathQuestionFunction, - this.AnswerIsCorrectFunction, + this.CreateMathQuestionFunctionContract, + this.AnswerIsCorrectFunctionContract, }, }; @@ -155,9 +155,9 @@ private async Task CreateStudentAssistantAgentAsync() var config = new AzureOpenAIConfig(endPoint, model, key); var llmConfig = new ConversableAgentConfig { - FunctionDefinitions = new[] + FunctionContracts = new[] { - this.AnswerQuestionFunction, + this.AnswerQuestionFunctionContract, }, ConfigList = new[] { diff --git a/dotnet/test/AutoGen.Tests/MiddlewareTest.cs b/dotnet/test/AutoGen.Tests/MiddlewareTest.cs index 5fbbf44f00a..dea20e8ccb4 100644 --- a/dotnet/test/AutoGen.Tests/MiddlewareTest.cs +++ b/dotnet/test/AutoGen.Tests/MiddlewareTest.cs @@ -100,7 +100,7 @@ public async Task FunctionCallMiddlewareTestAsync() // test 2 // middleware should invoke function call if agent reply is a function call message mw = new FunctionCallMiddleware( - functions: [this.EchoFunction], + functions: [this.EchoFunctionContract], functionMap: new Dictionary>> { { "echo", EchoWrapper } }); testAgent = functionCallAgent.RegisterMiddleware(mw); reply = await testAgent.SendAsync("hello"); diff --git a/dotnet/test/AutoGen.Tests/SingleAgentTest.cs b/dotnet/test/AutoGen.Tests/SingleAgentTest.cs index 2e2569fa840..92ce8079461 100644 --- a/dotnet/test/AutoGen.Tests/SingleAgentTest.cs +++ 
b/dotnet/test/AutoGen.Tests/SingleAgentTest.cs @@ -108,9 +108,9 @@ public async Task AssistantAgentFunctionCallTestAsync() var llmConfig = new ConversableAgentConfig { Temperature = 0, - FunctionDefinitions = new[] + FunctionContracts = new[] { - this.EchoAsyncFunction, + this.EchoAsyncFunctionContract, }, ConfigList = new[] { @@ -148,9 +148,9 @@ public async Task AssistantAgentFunctionCallSelfExecutionTestAsync() var config = this.CreateAzureOpenAIGPT35TurboConfig(); var llmConfig = new ConversableAgentConfig { - FunctionDefinitions = new[] + FunctionContracts = new[] { - this.EchoAsyncFunction, + this.EchoAsyncFunctionContract, }, ConfigList = new[] { diff --git a/dotnet/test/AutoGen.Tests/TwoAgentTest.cs b/dotnet/test/AutoGen.Tests/TwoAgentTest.cs index 1f5f5b01f9d..3faacfeeaf3 100644 --- a/dotnet/test/AutoGen.Tests/TwoAgentTest.cs +++ b/dotnet/test/AutoGen.Tests/TwoAgentTest.cs @@ -38,9 +38,9 @@ public async Task TwoAgentWeatherChatTestAsync() llmConfig: new ConversableAgentConfig { ConfigList = new[] { config }, - FunctionDefinitions = new[] + FunctionContracts = new[] { - this.GetWeatherFunction, + this.GetWeatherFunctionContract, }, }) .RegisterMiddleware(async (msgs, option, agent, ct) => From 4a1e288370b57ab33c5318c781f46b8ccd424690 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Fri, 1 Mar 2024 12:25:44 -0800 Subject: [PATCH 16/27] move AutoGen contrac to AutoGen.Core --- dotnet/AutoGen.sln | 16 +++++++++++- .../AutoGen.BasicSample.csproj | 2 -- .../TypeSafeFunctionCallCodeSnippet.cs | 1 - .../AutoGen.BasicSamples/GlobalUsing.cs | 4 +++ .../Agent/DefaultReplyAgent.cs | 2 +- .../Agent/GroupChatManager.cs | 2 +- .../Core => AutoGen.Core}/Agent/IAgent.cs | 3 +-- .../Agent}/IStreamingAgent.cs | 2 +- .../Agent/MiddlewareAgent.cs | 3 +-- .../Agent/MiddlewareStreamingAgent.cs | 3 +-- dotnet/src/AutoGen.Core/AutoGen.Core.csproj | 21 ++++++++++++++++ .../Extension/AgentExtension.cs | 2 +- .../Extension/GroupChatExtension.cs | 2 +- 
.../Extension/MessageExtension.cs | 2 +- .../Extension/MiddlewareExtension.cs | 3 +-- .../Function/FunctionAttribute.cs | 2 +- .../GroupChat/GroupChat.cs | 2 +- .../GroupChat/SequentialGroupChat.cs | 2 +- .../Core => AutoGen.Core}/IGroupChat.cs | 2 +- .../Core => AutoGen.Core}/ILLMConfig.cs | 2 +- .../Message/AggregateMessage.cs | 2 +- .../Core => AutoGen.Core}/Message/IMessage.cs | 4 +-- .../Message/ImageMessage.cs | 2 +- .../Core => AutoGen.Core}/Message/Message.cs | 9 +++---- .../Message/MessageEnvelope.cs | 2 +- .../Message/MultiModalMessage.cs | 2 +- .../Core => AutoGen.Core}/Message/Role.cs | 2 +- .../Message/TextMessage.cs | 2 +- .../Message/ToolCallMessage.cs | 2 +- .../Message/ToolCallResultMessage.cs | 2 +- .../Middleware/DelegateMiddleware.cs | 2 +- .../Middleware/DelegateStreamingMiddleware.cs | 2 +- .../Middleware/FunctionCallMiddleware.cs | 2 +- .../Middleware/IMiddleware.cs | 2 +- .../Middleware/IStreamingMiddleware.cs | 2 +- .../Middleware/MiddlewareContext.cs | 2 +- .../Middleware/PrintMessageMiddleware.cs | 2 +- .../Workflow/Workflow.cs | 2 +- .../AutoGen.DotnetInteractive/GlobalUsing.cs | 4 +++ .../AutoGen.LMStudio/AutoGen.LMStudio.csproj | 3 ++- dotnet/src/AutoGen.LMStudio/GlobalUsing.cs | 4 +++ dotnet/src/AutoGen.LMStudio/LMStudioConfig.cs | 1 - .../src/AutoGen.OpenAI/AutoGen.OpenAI.csproj | 25 +++++++++++++++++++ .../AzureOpenAIConfig.cs | 0 .../Extension/FunctionContractExtension.cs | 0 .../Extension/MessageExtension.cs | 0 .../OpenAI => AutoGen.OpenAI}/GPTAgent.cs | 1 - dotnet/src/AutoGen.OpenAI/GlobalUsing.cs | 4 +++ .../Middleware/OpenAIMessageConnector.cs | 1 - .../OpenAIClientAgent.cs | 0 .../OpenAI => AutoGen.OpenAI}/OpenAIConfig.cs | 0 .../AutoGen.SemanticKernel.csproj | 6 ++--- .../src/AutoGen.SemanticKernel/GlobalUsing.cs | 4 +++ .../FunctionCallGenerator.cs | 2 +- .../AutoGen/{Core => }/API/LLMConfigAPI.cs | 0 .../{Core => }/Agent/AssistantAgent.cs | 0 .../{Core => }/Agent/ConversableAgent.cs | 1 - .../{Core => 
}/Agent/UserProxyAgent.cs | 0 dotnet/src/AutoGen/AutoGen.csproj | 14 ++++++++--- .../{Core => }/ConversableAgentConfig.cs | 0 dotnet/src/AutoGen/GlobalUsing.cs | 4 +++ .../Middleware/HumanInputMiddleware.cs | 3 +-- .../GlobalUsing.cs | 4 +++ .../TopLevelStatementFunctionExample.cs | 2 -- dotnet/test/AutoGen.Tests/GlobalUsing.cs | 4 +++ dotnet/test/AutoGen.Tests/MiddlewareTest.cs | 1 - 66 files changed, 146 insertions(+), 65 deletions(-) create mode 100644 dotnet/sample/AutoGen.BasicSamples/GlobalUsing.cs rename dotnet/src/{AutoGen/Core => AutoGen.Core}/Agent/DefaultReplyAgent.cs (96%) rename dotnet/src/{AutoGen/Core => AutoGen.Core}/Agent/GroupChatManager.cs (97%) rename dotnet/src/{AutoGen/Core => AutoGen.Core}/Agent/IAgent.cs (98%) rename dotnet/src/{AutoGen/Core => AutoGen.Core/Agent}/IStreamingAgent.cs (95%) rename dotnet/src/{AutoGen/Core => AutoGen.Core}/Agent/MiddlewareAgent.cs (98%) rename dotnet/src/{AutoGen/Core => AutoGen.Core}/Agent/MiddlewareStreamingAgent.cs (98%) create mode 100644 dotnet/src/AutoGen.Core/AutoGen.Core.csproj rename dotnet/src/{AutoGen/Core => AutoGen.Core}/Extension/AgentExtension.cs (99%) rename dotnet/src/{AutoGen/Core => AutoGen.Core}/Extension/GroupChatExtension.cs (99%) rename dotnet/src/{AutoGen/Core => AutoGen.Core}/Extension/MessageExtension.cs (99%) rename dotnet/src/{AutoGen/Core => AutoGen.Core}/Extension/MiddlewareExtension.cs (99%) rename dotnet/src/{AutoGen/Core => AutoGen.Core}/Function/FunctionAttribute.cs (99%) rename dotnet/src/{AutoGen/Core => AutoGen.Core}/GroupChat/GroupChat.cs (99%) rename dotnet/src/{AutoGen/Core => AutoGen.Core}/GroupChat/SequentialGroupChat.cs (99%) rename dotnet/src/{AutoGen/Core => AutoGen.Core}/IGroupChat.cs (94%) rename dotnet/src/{AutoGen/Core => AutoGen.Core}/ILLMConfig.cs (82%) rename dotnet/src/{AutoGen/Core => AutoGen.Core}/Message/AggregateMessage.cs (98%) rename dotnet/src/{AutoGen/Core => AutoGen.Core}/Message/IMessage.cs (97%) rename dotnet/src/{AutoGen/Core => 
AutoGen.Core}/Message/ImageMessage.cs (95%) rename dotnet/src/{AutoGen/Core => AutoGen.Core}/Message/Message.cs (86%) rename dotnet/src/{AutoGen/Core => AutoGen.Core}/Message/MessageEnvelope.cs (95%) rename dotnet/src/{AutoGen/Core => AutoGen.Core}/Message/MultiModalMessage.cs (98%) rename dotnet/src/{AutoGen/Core => AutoGen.Core}/Message/Role.cs (97%) rename dotnet/src/{AutoGen/Core => AutoGen.Core}/Message/TextMessage.cs (95%) rename dotnet/src/{AutoGen/Core => AutoGen.Core}/Message/ToolCallMessage.cs (98%) rename dotnet/src/{AutoGen/Core => AutoGen.Core}/Message/ToolCallResultMessage.cs (98%) rename dotnet/src/{AutoGen/Core => AutoGen.Core}/Middleware/DelegateMiddleware.cs (97%) rename dotnet/src/{AutoGen/Core => AutoGen.Core}/Middleware/DelegateStreamingMiddleware.cs (96%) rename dotnet/src/{AutoGen/Core => AutoGen.Core}/Middleware/FunctionCallMiddleware.cs (99%) rename dotnet/src/{AutoGen/Core => AutoGen.Core}/Middleware/IMiddleware.cs (94%) rename dotnet/src/{AutoGen/Core => AutoGen.Core}/Middleware/IStreamingMiddleware.cs (93%) rename dotnet/src/{AutoGen/Core => AutoGen.Core}/Middleware/MiddlewareContext.cs (94%) rename dotnet/src/{AutoGen/Core => AutoGen.Core}/Middleware/PrintMessageMiddleware.cs (95%) rename dotnet/src/{AutoGen/Core => AutoGen.Core}/Workflow/Workflow.cs (99%) create mode 100644 dotnet/src/AutoGen.DotnetInteractive/GlobalUsing.cs create mode 100644 dotnet/src/AutoGen.LMStudio/GlobalUsing.cs create mode 100644 dotnet/src/AutoGen.OpenAI/AutoGen.OpenAI.csproj rename dotnet/src/{AutoGen/OpenAI => AutoGen.OpenAI}/AzureOpenAIConfig.cs (100%) rename dotnet/src/{AutoGen/OpenAI => AutoGen.OpenAI}/Extension/FunctionContractExtension.cs (100%) rename dotnet/src/{AutoGen/OpenAI => AutoGen.OpenAI}/Extension/MessageExtension.cs (100%) rename dotnet/src/{AutoGen/OpenAI => AutoGen.OpenAI}/GPTAgent.cs (99%) create mode 100644 dotnet/src/AutoGen.OpenAI/GlobalUsing.cs rename dotnet/src/{AutoGen/OpenAI => AutoGen.OpenAI}/Middleware/OpenAIMessageConnector.cs 
(99%) rename dotnet/src/{AutoGen/OpenAI => AutoGen.OpenAI}/OpenAIClientAgent.cs (100%) rename dotnet/src/{AutoGen/OpenAI => AutoGen.OpenAI}/OpenAIConfig.cs (100%) create mode 100644 dotnet/src/AutoGen.SemanticKernel/GlobalUsing.cs rename dotnet/src/AutoGen/{Core => }/API/LLMConfigAPI.cs (100%) rename dotnet/src/AutoGen/{Core => }/Agent/AssistantAgent.cs (100%) rename dotnet/src/AutoGen/{Core => }/Agent/ConversableAgent.cs (99%) rename dotnet/src/AutoGen/{Core => }/Agent/UserProxyAgent.cs (100%) rename dotnet/src/AutoGen/{Core => }/ConversableAgentConfig.cs (100%) create mode 100644 dotnet/src/AutoGen/GlobalUsing.cs rename dotnet/src/AutoGen/{Core => }/Middleware/HumanInputMiddleware.cs (98%) create mode 100644 dotnet/test/AutoGen.SourceGenerator.Tests/GlobalUsing.cs create mode 100644 dotnet/test/AutoGen.Tests/GlobalUsing.cs diff --git a/dotnet/AutoGen.sln b/dotnet/AutoGen.sln index bf3fee6cd93..7a70e04f79e 100644 --- a/dotnet/AutoGen.sln +++ b/dotnet/AutoGen.sln @@ -23,7 +23,11 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.DotnetInteractive", EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.LMStudio", "src\AutoGen.LMStudio\AutoGen.LMStudio.csproj", "{F98BDA9B-8657-4BA8-9B03-BAEA454CAE60}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AutoGen.SemanticKernel", "src\AutoGen.SemanticKernel\AutoGen.SemanticKernel.csproj", "{45D6FC80-36F3-4967-9663-E20B63824621}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AutoGen.SemanticKernel", "src\AutoGen.SemanticKernel\AutoGen.SemanticKernel.csproj", "{45D6FC80-36F3-4967-9663-E20B63824621}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AutoGen.Core", "src\AutoGen.Core\AutoGen.Core.csproj", "{D58D43D1-0617-4A3D-9932-C773E6398535}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AutoGen.OpenAI", "src\AutoGen.OpenAI\AutoGen.OpenAI.csproj", "{63445BB7-DBB9-4AEF-9D6F-98BBE75EE1EC}" EndProject Global 
GlobalSection(SolutionConfigurationPlatforms) = preSolution @@ -63,6 +67,14 @@ Global {45D6FC80-36F3-4967-9663-E20B63824621}.Debug|Any CPU.Build.0 = Debug|Any CPU {45D6FC80-36F3-4967-9663-E20B63824621}.Release|Any CPU.ActiveCfg = Release|Any CPU {45D6FC80-36F3-4967-9663-E20B63824621}.Release|Any CPU.Build.0 = Release|Any CPU + {D58D43D1-0617-4A3D-9932-C773E6398535}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {D58D43D1-0617-4A3D-9932-C773E6398535}.Debug|Any CPU.Build.0 = Debug|Any CPU + {D58D43D1-0617-4A3D-9932-C773E6398535}.Release|Any CPU.ActiveCfg = Release|Any CPU + {D58D43D1-0617-4A3D-9932-C773E6398535}.Release|Any CPU.Build.0 = Release|Any CPU + {63445BB7-DBB9-4AEF-9D6F-98BBE75EE1EC}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {63445BB7-DBB9-4AEF-9D6F-98BBE75EE1EC}.Debug|Any CPU.Build.0 = Debug|Any CPU + {63445BB7-DBB9-4AEF-9D6F-98BBE75EE1EC}.Release|Any CPU.ActiveCfg = Release|Any CPU + {63445BB7-DBB9-4AEF-9D6F-98BBE75EE1EC}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -76,6 +88,8 @@ Global {B61D8008-7FB7-4C0E-8044-3A74AA63A596} = {18BF8DD7-0585-48BF-8F97-AD333080CE06} {F98BDA9B-8657-4BA8-9B03-BAEA454CAE60} = {18BF8DD7-0585-48BF-8F97-AD333080CE06} {45D6FC80-36F3-4967-9663-E20B63824621} = {18BF8DD7-0585-48BF-8F97-AD333080CE06} + {D58D43D1-0617-4A3D-9932-C773E6398535} = {18BF8DD7-0585-48BF-8F97-AD333080CE06} + {63445BB7-DBB9-4AEF-9D6F-98BBE75EE1EC} = {18BF8DD7-0585-48BF-8F97-AD333080CE06} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {93384647-528D-46C8-922C-8DB36A382F0B} diff --git a/dotnet/sample/AutoGen.BasicSamples/AutoGen.BasicSample.csproj b/dotnet/sample/AutoGen.BasicSamples/AutoGen.BasicSample.csproj index 75f5c546b29..f2babce7f73 100644 --- a/dotnet/sample/AutoGen.BasicSamples/AutoGen.BasicSample.csproj +++ b/dotnet/sample/AutoGen.BasicSamples/AutoGen.BasicSample.csproj @@ -11,8 +11,6 @@ - - diff --git 
a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/TypeSafeFunctionCallCodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/TypeSafeFunctionCallCodeSnippet.cs index ea00ab44298..1cf53e7fe6e 100644 --- a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/TypeSafeFunctionCallCodeSnippet.cs +++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/TypeSafeFunctionCallCodeSnippet.cs @@ -2,7 +2,6 @@ // TypeSafeFunctionCallCodeSnippet.cs using System.Text.Json; -using AutoGen; using AutoGen.OpenAI.Extension; using Azure.AI.OpenAI; diff --git a/dotnet/sample/AutoGen.BasicSamples/GlobalUsing.cs b/dotnet/sample/AutoGen.BasicSamples/GlobalUsing.cs new file mode 100644 index 00000000000..d00ae3ce4fc --- /dev/null +++ b/dotnet/sample/AutoGen.BasicSamples/GlobalUsing.cs @@ -0,0 +1,4 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// globalUsing.cs + +global using AutoGen.Core; diff --git a/dotnet/src/AutoGen/Core/Agent/DefaultReplyAgent.cs b/dotnet/src/AutoGen.Core/Agent/DefaultReplyAgent.cs similarity index 96% rename from dotnet/src/AutoGen/Core/Agent/DefaultReplyAgent.cs rename to dotnet/src/AutoGen.Core/Agent/DefaultReplyAgent.cs index 3c05440166c..647a2ece79d 100644 --- a/dotnet/src/AutoGen/Core/Agent/DefaultReplyAgent.cs +++ b/dotnet/src/AutoGen.Core/Agent/DefaultReplyAgent.cs @@ -5,7 +5,7 @@ using System.Threading; using System.Threading.Tasks; -namespace AutoGen; +namespace AutoGen.Core; public class DefaultReplyAgent : IAgent { diff --git a/dotnet/src/AutoGen/Core/Agent/GroupChatManager.cs b/dotnet/src/AutoGen.Core/Agent/GroupChatManager.cs similarity index 97% rename from dotnet/src/AutoGen/Core/Agent/GroupChatManager.cs rename to dotnet/src/AutoGen.Core/Agent/GroupChatManager.cs index 6c92041cc04..db40f801dea 100644 --- a/dotnet/src/AutoGen/Core/Agent/GroupChatManager.cs +++ b/dotnet/src/AutoGen.Core/Agent/GroupChatManager.cs @@ -7,7 +7,7 @@ using System.Threading; using System.Threading.Tasks; -namespace AutoGen; +namespace AutoGen.Core; public 
class GroupChatManager : IAgent { diff --git a/dotnet/src/AutoGen/Core/Agent/IAgent.cs b/dotnet/src/AutoGen.Core/Agent/IAgent.cs similarity index 98% rename from dotnet/src/AutoGen/Core/Agent/IAgent.cs rename to dotnet/src/AutoGen.Core/Agent/IAgent.cs index 53a0493d47b..b9149008480 100644 --- a/dotnet/src/AutoGen/Core/Agent/IAgent.cs +++ b/dotnet/src/AutoGen.Core/Agent/IAgent.cs @@ -6,8 +6,7 @@ using System.Threading; using System.Threading.Tasks; -namespace AutoGen; - +namespace AutoGen.Core; public interface IAgent { public string Name { get; } diff --git a/dotnet/src/AutoGen/Core/IStreamingAgent.cs b/dotnet/src/AutoGen.Core/Agent/IStreamingAgent.cs similarity index 95% rename from dotnet/src/AutoGen/Core/IStreamingAgent.cs rename to dotnet/src/AutoGen.Core/Agent/IStreamingAgent.cs index 894c4726882..3fa121a7b08 100644 --- a/dotnet/src/AutoGen/Core/IStreamingAgent.cs +++ b/dotnet/src/AutoGen.Core/Agent/IStreamingAgent.cs @@ -5,7 +5,7 @@ using System.Threading; using System.Threading.Tasks; -namespace AutoGen; +namespace AutoGen.Core; /// /// agent that supports streaming reply diff --git a/dotnet/src/AutoGen/Core/Agent/MiddlewareAgent.cs b/dotnet/src/AutoGen.Core/Agent/MiddlewareAgent.cs similarity index 98% rename from dotnet/src/AutoGen/Core/Agent/MiddlewareAgent.cs rename to dotnet/src/AutoGen.Core/Agent/MiddlewareAgent.cs index 3c4e99e9b2f..71c8bb7e514 100644 --- a/dotnet/src/AutoGen/Core/Agent/MiddlewareAgent.cs +++ b/dotnet/src/AutoGen.Core/Agent/MiddlewareAgent.cs @@ -5,9 +5,8 @@ using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; -using AutoGen.Core.Middleware; -namespace AutoGen; +namespace AutoGen.Core; /// /// An agent that allows you to add middleware and modify the behavior of an existing agent. 
diff --git a/dotnet/src/AutoGen/Core/Agent/MiddlewareStreamingAgent.cs b/dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs similarity index 98% rename from dotnet/src/AutoGen/Core/Agent/MiddlewareStreamingAgent.cs rename to dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs index 4684e15fbb1..3aaba4da61c 100644 --- a/dotnet/src/AutoGen/Core/Agent/MiddlewareStreamingAgent.cs +++ b/dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs @@ -5,9 +5,8 @@ using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; -using AutoGen.Core.Middleware; -namespace AutoGen; +namespace AutoGen.Core; public class MiddlewareStreamingAgent : IStreamingAgent { diff --git a/dotnet/src/AutoGen.Core/AutoGen.Core.csproj b/dotnet/src/AutoGen.Core/AutoGen.Core.csproj new file mode 100644 index 00000000000..018cd23a446 --- /dev/null +++ b/dotnet/src/AutoGen.Core/AutoGen.Core.csproj @@ -0,0 +1,21 @@ + + + netstandard2.0 + AutoGen.Core + + + + + + + AutoGen.Core + + Core library for AutoGen. This package provides contracts and core functionalities for AutoGen. 
+ + + + + + + + diff --git a/dotnet/src/AutoGen/Core/Extension/AgentExtension.cs b/dotnet/src/AutoGen.Core/Extension/AgentExtension.cs similarity index 99% rename from dotnet/src/AutoGen/Core/Extension/AgentExtension.cs rename to dotnet/src/AutoGen.Core/Extension/AgentExtension.cs index 9bb10fa9d7b..47968497cf9 100644 --- a/dotnet/src/AutoGen/Core/Extension/AgentExtension.cs +++ b/dotnet/src/AutoGen.Core/Extension/AgentExtension.cs @@ -6,7 +6,7 @@ using System.Threading; using System.Threading.Tasks; -namespace AutoGen; +namespace AutoGen.Core; public static class AgentExtension { diff --git a/dotnet/src/AutoGen/Core/Extension/GroupChatExtension.cs b/dotnet/src/AutoGen.Core/Extension/GroupChatExtension.cs similarity index 99% rename from dotnet/src/AutoGen/Core/Extension/GroupChatExtension.cs rename to dotnet/src/AutoGen.Core/Extension/GroupChatExtension.cs index 5a48c2e83cd..52360887758 100644 --- a/dotnet/src/AutoGen/Core/Extension/GroupChatExtension.cs +++ b/dotnet/src/AutoGen.Core/Extension/GroupChatExtension.cs @@ -4,7 +4,7 @@ using System.Collections.Generic; using System.Linq; -namespace AutoGen; +namespace AutoGen.Core; public static class GroupChatExtension { diff --git a/dotnet/src/AutoGen/Core/Extension/MessageExtension.cs b/dotnet/src/AutoGen.Core/Extension/MessageExtension.cs similarity index 99% rename from dotnet/src/AutoGen/Core/Extension/MessageExtension.cs rename to dotnet/src/AutoGen.Core/Extension/MessageExtension.cs index e529571a503..47dbad55e30 100644 --- a/dotnet/src/AutoGen/Core/Extension/MessageExtension.cs +++ b/dotnet/src/AutoGen.Core/Extension/MessageExtension.cs @@ -5,7 +5,7 @@ using System.Linq; using System.Text; -namespace AutoGen; +namespace AutoGen.Core; public static class MessageExtension { diff --git a/dotnet/src/AutoGen/Core/Extension/MiddlewareExtension.cs b/dotnet/src/AutoGen.Core/Extension/MiddlewareExtension.cs similarity index 99% rename from dotnet/src/AutoGen/Core/Extension/MiddlewareExtension.cs rename to 
dotnet/src/AutoGen.Core/Extension/MiddlewareExtension.cs index d5223111460..6008b927605 100644 --- a/dotnet/src/AutoGen/Core/Extension/MiddlewareExtension.cs +++ b/dotnet/src/AutoGen.Core/Extension/MiddlewareExtension.cs @@ -5,9 +5,8 @@ using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; -using AutoGen.Core.Middleware; -namespace AutoGen; +namespace AutoGen.Core; public static class MiddlewareExtension { diff --git a/dotnet/src/AutoGen/Core/Function/FunctionAttribute.cs b/dotnet/src/AutoGen.Core/Function/FunctionAttribute.cs similarity index 99% rename from dotnet/src/AutoGen/Core/Function/FunctionAttribute.cs rename to dotnet/src/AutoGen.Core/Function/FunctionAttribute.cs index 202b401add3..d0f2d8fa8d8 100644 --- a/dotnet/src/AutoGen/Core/Function/FunctionAttribute.cs +++ b/dotnet/src/AutoGen.Core/Function/FunctionAttribute.cs @@ -4,7 +4,7 @@ using System; using System.Collections.Generic; -namespace AutoGen; +namespace AutoGen.Core; [AttributeUsage(AttributeTargets.Method, Inherited = false, AllowMultiple = false)] public class FunctionAttribute : Attribute diff --git a/dotnet/src/AutoGen/Core/GroupChat/GroupChat.cs b/dotnet/src/AutoGen.Core/GroupChat/GroupChat.cs similarity index 99% rename from dotnet/src/AutoGen/Core/GroupChat/GroupChat.cs rename to dotnet/src/AutoGen.Core/GroupChat/GroupChat.cs index 24b2ad00894..0659f7e7ea1 100644 --- a/dotnet/src/AutoGen/Core/GroupChat/GroupChat.cs +++ b/dotnet/src/AutoGen.Core/GroupChat/GroupChat.cs @@ -7,7 +7,7 @@ using System.Threading; using System.Threading.Tasks; -namespace AutoGen; +namespace AutoGen.Core; public class GroupChat : IGroupChat { diff --git a/dotnet/src/AutoGen/Core/GroupChat/SequentialGroupChat.cs b/dotnet/src/AutoGen.Core/GroupChat/SequentialGroupChat.cs similarity index 99% rename from dotnet/src/AutoGen/Core/GroupChat/SequentialGroupChat.cs rename to dotnet/src/AutoGen.Core/GroupChat/SequentialGroupChat.cs index 9a5dcbe4930..85f0f0693c5 100644 --- 
a/dotnet/src/AutoGen/Core/GroupChat/SequentialGroupChat.cs +++ b/dotnet/src/AutoGen.Core/GroupChat/SequentialGroupChat.cs @@ -7,7 +7,7 @@ using System.Threading; using System.Threading.Tasks; -namespace AutoGen; +namespace AutoGen.Core; public class SequentialGroupChat : IGroupChat { diff --git a/dotnet/src/AutoGen/Core/IGroupChat.cs b/dotnet/src/AutoGen.Core/IGroupChat.cs similarity index 94% rename from dotnet/src/AutoGen/Core/IGroupChat.cs rename to dotnet/src/AutoGen.Core/IGroupChat.cs index 4c0be66f4a2..36859c4d1f1 100644 --- a/dotnet/src/AutoGen/Core/IGroupChat.cs +++ b/dotnet/src/AutoGen.Core/IGroupChat.cs @@ -5,7 +5,7 @@ using System.Threading; using System.Threading.Tasks; -namespace AutoGen; +namespace AutoGen.Core; public interface IGroupChat { diff --git a/dotnet/src/AutoGen/Core/ILLMConfig.cs b/dotnet/src/AutoGen.Core/ILLMConfig.cs similarity index 82% rename from dotnet/src/AutoGen/Core/ILLMConfig.cs rename to dotnet/src/AutoGen.Core/ILLMConfig.cs index ef836aea5fe..fd2a90db02a 100644 --- a/dotnet/src/AutoGen/Core/ILLMConfig.cs +++ b/dotnet/src/AutoGen.Core/ILLMConfig.cs @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. 
// ILLMConfig.cs -namespace AutoGen; +namespace AutoGen.Core; public interface ILLMConfig { diff --git a/dotnet/src/AutoGen/Core/Message/AggregateMessage.cs b/dotnet/src/AutoGen.Core/Message/AggregateMessage.cs similarity index 98% rename from dotnet/src/AutoGen/Core/Message/AggregateMessage.cs rename to dotnet/src/AutoGen.Core/Message/AggregateMessage.cs index a375f7c2b38..c7eee1316ee 100644 --- a/dotnet/src/AutoGen/Core/Message/AggregateMessage.cs +++ b/dotnet/src/AutoGen.Core/Message/AggregateMessage.cs @@ -4,7 +4,7 @@ using System; using System.Collections.Generic; -namespace AutoGen; +namespace AutoGen.Core; public class AggregateMessage : IMessage where TMessage1 : IMessage diff --git a/dotnet/src/AutoGen/Core/Message/IMessage.cs b/dotnet/src/AutoGen.Core/Message/IMessage.cs similarity index 97% rename from dotnet/src/AutoGen/Core/Message/IMessage.cs rename to dotnet/src/AutoGen.Core/Message/IMessage.cs index 2d6671be3c8..24d8e383875 100644 --- a/dotnet/src/AutoGen/Core/Message/IMessage.cs +++ b/dotnet/src/AutoGen.Core/Message/IMessage.cs @@ -1,9 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // IMessage.cs -using AutoGen.Core.Middleware; - -namespace AutoGen; +namespace AutoGen.Core; /// /// The universal message interface for all message types in AutoGen. diff --git a/dotnet/src/AutoGen/Core/Message/ImageMessage.cs b/dotnet/src/AutoGen.Core/Message/ImageMessage.cs similarity index 95% rename from dotnet/src/AutoGen/Core/Message/ImageMessage.cs rename to dotnet/src/AutoGen.Core/Message/ImageMessage.cs index 4c94b8a63f0..753a6d6e1e4 100644 --- a/dotnet/src/AutoGen/Core/Message/ImageMessage.cs +++ b/dotnet/src/AutoGen.Core/Message/ImageMessage.cs @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. 
// ImageMessage.cs -namespace AutoGen; +namespace AutoGen.Core; public class ImageMessage : IMessage { diff --git a/dotnet/src/AutoGen/Core/Message/Message.cs b/dotnet/src/AutoGen.Core/Message/Message.cs similarity index 86% rename from dotnet/src/AutoGen/Core/Message/Message.cs rename to dotnet/src/AutoGen.Core/Message/Message.cs index f84b296baf9..ec4751b9344 100644 --- a/dotnet/src/AutoGen/Core/Message/Message.cs +++ b/dotnet/src/AutoGen.Core/Message/Message.cs @@ -2,9 +2,8 @@ // Message.cs using System.Collections.Generic; -using Azure.AI.OpenAI; -namespace AutoGen; +namespace AutoGen.Core; public class Message : IMessage { @@ -12,13 +11,13 @@ public class Message : IMessage Role role, string? content, string? from = null, - FunctionCall? functionCall = null) + ToolCall? toolCall = null) { this.Role = role; this.Content = content; this.From = from; - this.FunctionName = functionCall?.Name; - this.FunctionArguments = functionCall?.Arguments; + this.FunctionName = toolCall?.FunctionName; + this.FunctionArguments = toolCall?.FunctionArguments; } public Message(Message other) diff --git a/dotnet/src/AutoGen/Core/Message/MessageEnvelope.cs b/dotnet/src/AutoGen.Core/Message/MessageEnvelope.cs similarity index 95% rename from dotnet/src/AutoGen/Core/Message/MessageEnvelope.cs rename to dotnet/src/AutoGen.Core/Message/MessageEnvelope.cs index b3bd85dd50d..a6bad770311 100644 --- a/dotnet/src/AutoGen/Core/Message/MessageEnvelope.cs +++ b/dotnet/src/AutoGen.Core/Message/MessageEnvelope.cs @@ -3,7 +3,7 @@ using System.Collections.Generic; -namespace AutoGen; +namespace AutoGen.Core; public class MessageEnvelope : IMessage { diff --git a/dotnet/src/AutoGen/Core/Message/MultiModalMessage.cs b/dotnet/src/AutoGen.Core/Message/MultiModalMessage.cs similarity index 98% rename from dotnet/src/AutoGen/Core/Message/MultiModalMessage.cs rename to dotnet/src/AutoGen.Core/Message/MultiModalMessage.cs index cfbfce677a5..3fe1d34383b 100644 --- 
a/dotnet/src/AutoGen/Core/Message/MultiModalMessage.cs +++ b/dotnet/src/AutoGen.Core/Message/MultiModalMessage.cs @@ -4,7 +4,7 @@ using System; using System.Collections.Generic; -namespace AutoGen; +namespace AutoGen.Core; public class MultiModalMessage : IMessage { diff --git a/dotnet/src/AutoGen/Core/Message/Role.cs b/dotnet/src/AutoGen.Core/Message/Role.cs similarity index 97% rename from dotnet/src/AutoGen/Core/Message/Role.cs rename to dotnet/src/AutoGen.Core/Message/Role.cs index 4be88007ae9..8253543a81c 100644 --- a/dotnet/src/AutoGen/Core/Message/Role.cs +++ b/dotnet/src/AutoGen.Core/Message/Role.cs @@ -3,7 +3,7 @@ using System; -namespace AutoGen; +namespace AutoGen.Core; public readonly struct Role : IEquatable { diff --git a/dotnet/src/AutoGen/Core/Message/TextMessage.cs b/dotnet/src/AutoGen.Core/Message/TextMessage.cs similarity index 95% rename from dotnet/src/AutoGen/Core/Message/TextMessage.cs rename to dotnet/src/AutoGen.Core/Message/TextMessage.cs index 183907bd5a7..9c9f8c1c828 100644 --- a/dotnet/src/AutoGen/Core/Message/TextMessage.cs +++ b/dotnet/src/AutoGen.Core/Message/TextMessage.cs @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. 
// TextMessage.cs -namespace AutoGen; +namespace AutoGen.Core; public class TextMessage : IMessage { diff --git a/dotnet/src/AutoGen/Core/Message/ToolCallMessage.cs b/dotnet/src/AutoGen.Core/Message/ToolCallMessage.cs similarity index 98% rename from dotnet/src/AutoGen/Core/Message/ToolCallMessage.cs rename to dotnet/src/AutoGen.Core/Message/ToolCallMessage.cs index 1fb291a6551..6a6123037ed 100644 --- a/dotnet/src/AutoGen/Core/Message/ToolCallMessage.cs +++ b/dotnet/src/AutoGen.Core/Message/ToolCallMessage.cs @@ -5,7 +5,7 @@ using System.Linq; using System.Text; -namespace AutoGen; +namespace AutoGen.Core; public class ToolCall { diff --git a/dotnet/src/AutoGen/Core/Message/ToolCallResultMessage.cs b/dotnet/src/AutoGen.Core/Message/ToolCallResultMessage.cs similarity index 98% rename from dotnet/src/AutoGen/Core/Message/ToolCallResultMessage.cs rename to dotnet/src/AutoGen.Core/Message/ToolCallResultMessage.cs index 4caa87ea933..99c7740849a 100644 --- a/dotnet/src/AutoGen/Core/Message/ToolCallResultMessage.cs +++ b/dotnet/src/AutoGen.Core/Message/ToolCallResultMessage.cs @@ -5,7 +5,7 @@ using System.Linq; using System.Text; -namespace AutoGen; +namespace AutoGen.Core; public class ToolCallResultMessage : IMessage { diff --git a/dotnet/src/AutoGen/Core/Middleware/DelegateMiddleware.cs b/dotnet/src/AutoGen.Core/Middleware/DelegateMiddleware.cs similarity index 97% rename from dotnet/src/AutoGen/Core/Middleware/DelegateMiddleware.cs rename to dotnet/src/AutoGen.Core/Middleware/DelegateMiddleware.cs index 95e9a81dfe3..79360e0428f 100644 --- a/dotnet/src/AutoGen/Core/Middleware/DelegateMiddleware.cs +++ b/dotnet/src/AutoGen.Core/Middleware/DelegateMiddleware.cs @@ -5,7 +5,7 @@ using System.Threading; using System.Threading.Tasks; -namespace AutoGen.Core.Middleware; +namespace AutoGen.Core; internal class DelegateMiddleware : IMiddleware { diff --git a/dotnet/src/AutoGen/Core/Middleware/DelegateStreamingMiddleware.cs 
b/dotnet/src/AutoGen.Core/Middleware/DelegateStreamingMiddleware.cs similarity index 96% rename from dotnet/src/AutoGen/Core/Middleware/DelegateStreamingMiddleware.cs rename to dotnet/src/AutoGen.Core/Middleware/DelegateStreamingMiddleware.cs index 453cb61a16d..a366c954193 100644 --- a/dotnet/src/AutoGen/Core/Middleware/DelegateStreamingMiddleware.cs +++ b/dotnet/src/AutoGen.Core/Middleware/DelegateStreamingMiddleware.cs @@ -5,7 +5,7 @@ using System.Threading; using System.Threading.Tasks; -namespace AutoGen.Core.Middleware; +namespace AutoGen.Core; internal class DelegateStreamingMiddleware : IStreamingMiddleware { diff --git a/dotnet/src/AutoGen/Core/Middleware/FunctionCallMiddleware.cs b/dotnet/src/AutoGen.Core/Middleware/FunctionCallMiddleware.cs similarity index 99% rename from dotnet/src/AutoGen/Core/Middleware/FunctionCallMiddleware.cs rename to dotnet/src/AutoGen.Core/Middleware/FunctionCallMiddleware.cs index b7190611bf4..ecff6cf401f 100644 --- a/dotnet/src/AutoGen/Core/Middleware/FunctionCallMiddleware.cs +++ b/dotnet/src/AutoGen.Core/Middleware/FunctionCallMiddleware.cs @@ -7,7 +7,7 @@ using System.Threading; using System.Threading.Tasks; -namespace AutoGen.Core.Middleware; +namespace AutoGen.Core; /// /// The middleware that process function call message that both send to an agent or reply from an agent. 
diff --git a/dotnet/src/AutoGen/Core/Middleware/IMiddleware.cs b/dotnet/src/AutoGen.Core/Middleware/IMiddleware.cs similarity index 94% rename from dotnet/src/AutoGen/Core/Middleware/IMiddleware.cs rename to dotnet/src/AutoGen.Core/Middleware/IMiddleware.cs index 461aaebb7f7..2813ee9cdb4 100644 --- a/dotnet/src/AutoGen/Core/Middleware/IMiddleware.cs +++ b/dotnet/src/AutoGen.Core/Middleware/IMiddleware.cs @@ -4,7 +4,7 @@ using System.Threading; using System.Threading.Tasks; -namespace AutoGen.Core.Middleware; +namespace AutoGen.Core; /// /// The middleware interface diff --git a/dotnet/src/AutoGen/Core/Middleware/IStreamingMiddleware.cs b/dotnet/src/AutoGen.Core/Middleware/IStreamingMiddleware.cs similarity index 93% rename from dotnet/src/AutoGen/Core/Middleware/IStreamingMiddleware.cs rename to dotnet/src/AutoGen.Core/Middleware/IStreamingMiddleware.cs index 5541c6fb571..dc4d98cce7e 100644 --- a/dotnet/src/AutoGen/Core/Middleware/IStreamingMiddleware.cs +++ b/dotnet/src/AutoGen.Core/Middleware/IStreamingMiddleware.cs @@ -5,7 +5,7 @@ using System.Threading; using System.Threading.Tasks; -namespace AutoGen.Core.Middleware; +namespace AutoGen.Core; /// /// The streaming middleware interface diff --git a/dotnet/src/AutoGen/Core/Middleware/MiddlewareContext.cs b/dotnet/src/AutoGen.Core/Middleware/MiddlewareContext.cs similarity index 94% rename from dotnet/src/AutoGen/Core/Middleware/MiddlewareContext.cs rename to dotnet/src/AutoGen.Core/Middleware/MiddlewareContext.cs index 9d2b29787b6..a608d0baf81 100644 --- a/dotnet/src/AutoGen/Core/Middleware/MiddlewareContext.cs +++ b/dotnet/src/AutoGen.Core/Middleware/MiddlewareContext.cs @@ -3,7 +3,7 @@ using System.Collections.Generic; -namespace AutoGen.Core.Middleware; +namespace AutoGen.Core; public class MiddlewareContext { diff --git a/dotnet/src/AutoGen/Core/Middleware/PrintMessageMiddleware.cs b/dotnet/src/AutoGen.Core/Middleware/PrintMessageMiddleware.cs similarity index 95% rename from 
dotnet/src/AutoGen/Core/Middleware/PrintMessageMiddleware.cs rename to dotnet/src/AutoGen.Core/Middleware/PrintMessageMiddleware.cs index 33a6118f1c4..9738a3a8e4a 100644 --- a/dotnet/src/AutoGen/Core/Middleware/PrintMessageMiddleware.cs +++ b/dotnet/src/AutoGen.Core/Middleware/PrintMessageMiddleware.cs @@ -5,7 +5,7 @@ using System.Threading; using System.Threading.Tasks; -namespace AutoGen.Core.Middleware; +namespace AutoGen.Core; /// /// The middleware that prints the reply from agent to the console. diff --git a/dotnet/src/AutoGen/Core/Workflow/Workflow.cs b/dotnet/src/AutoGen.Core/Workflow/Workflow.cs similarity index 99% rename from dotnet/src/AutoGen/Core/Workflow/Workflow.cs rename to dotnet/src/AutoGen.Core/Workflow/Workflow.cs index 83a795f403c..65a2f8b080d 100644 --- a/dotnet/src/AutoGen/Core/Workflow/Workflow.cs +++ b/dotnet/src/AutoGen.Core/Workflow/Workflow.cs @@ -6,7 +6,7 @@ using System.Linq; using System.Threading.Tasks; -namespace AutoGen; +namespace AutoGen.Core; public class Workflow { diff --git a/dotnet/src/AutoGen.DotnetInteractive/GlobalUsing.cs b/dotnet/src/AutoGen.DotnetInteractive/GlobalUsing.cs new file mode 100644 index 00000000000..d00ae3ce4fc --- /dev/null +++ b/dotnet/src/AutoGen.DotnetInteractive/GlobalUsing.cs @@ -0,0 +1,4 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// globalUsing.cs + +global using AutoGen.Core; diff --git a/dotnet/src/AutoGen.LMStudio/AutoGen.LMStudio.csproj b/dotnet/src/AutoGen.LMStudio/AutoGen.LMStudio.csproj index a84d616e3d5..b738fe02bb7 100644 --- a/dotnet/src/AutoGen.LMStudio/AutoGen.LMStudio.csproj +++ b/dotnet/src/AutoGen.LMStudio/AutoGen.LMStudio.csproj @@ -16,7 +16,8 @@ - + + diff --git a/dotnet/src/AutoGen.LMStudio/GlobalUsing.cs b/dotnet/src/AutoGen.LMStudio/GlobalUsing.cs new file mode 100644 index 00000000000..d00ae3ce4fc --- /dev/null +++ b/dotnet/src/AutoGen.LMStudio/GlobalUsing.cs @@ -0,0 +1,4 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// globalUsing.cs + +global using AutoGen.Core; diff --git a/dotnet/src/AutoGen.LMStudio/LMStudioConfig.cs b/dotnet/src/AutoGen.LMStudio/LMStudioConfig.cs index 445742508fe..4cf18210a43 100644 --- a/dotnet/src/AutoGen.LMStudio/LMStudioConfig.cs +++ b/dotnet/src/AutoGen.LMStudio/LMStudioConfig.cs @@ -2,7 +2,6 @@ // LMStudioConfig.cs using System; -using AutoGen; /// /// Add support for consuming openai-like API from LM Studio diff --git a/dotnet/src/AutoGen.OpenAI/AutoGen.OpenAI.csproj b/dotnet/src/AutoGen.OpenAI/AutoGen.OpenAI.csproj new file mode 100644 index 00000000000..182d112227b --- /dev/null +++ b/dotnet/src/AutoGen.OpenAI/AutoGen.OpenAI.csproj @@ -0,0 +1,25 @@ + + + netstandard2.0 + AutoGen.OpenAI + + + + + + + AutoGen.OpenAI + + OpenAI Intergration for AutoGen. + + + + + + + + + + + + diff --git a/dotnet/src/AutoGen/OpenAI/AzureOpenAIConfig.cs b/dotnet/src/AutoGen.OpenAI/AzureOpenAIConfig.cs similarity index 100% rename from dotnet/src/AutoGen/OpenAI/AzureOpenAIConfig.cs rename to dotnet/src/AutoGen.OpenAI/AzureOpenAIConfig.cs diff --git a/dotnet/src/AutoGen/OpenAI/Extension/FunctionContractExtension.cs b/dotnet/src/AutoGen.OpenAI/Extension/FunctionContractExtension.cs similarity index 100% rename from dotnet/src/AutoGen/OpenAI/Extension/FunctionContractExtension.cs rename to dotnet/src/AutoGen.OpenAI/Extension/FunctionContractExtension.cs diff --git a/dotnet/src/AutoGen/OpenAI/Extension/MessageExtension.cs b/dotnet/src/AutoGen.OpenAI/Extension/MessageExtension.cs similarity index 100% rename from dotnet/src/AutoGen/OpenAI/Extension/MessageExtension.cs rename to dotnet/src/AutoGen.OpenAI/Extension/MessageExtension.cs diff --git a/dotnet/src/AutoGen/OpenAI/GPTAgent.cs b/dotnet/src/AutoGen.OpenAI/GPTAgent.cs similarity index 99% rename from dotnet/src/AutoGen/OpenAI/GPTAgent.cs rename to dotnet/src/AutoGen.OpenAI/GPTAgent.cs index 6bb6cde558d..fcb96588d23 100644 --- a/dotnet/src/AutoGen/OpenAI/GPTAgent.cs +++ b/dotnet/src/AutoGen.OpenAI/GPTAgent.cs @@ -6,7 
+6,6 @@ using System.Linq; using System.Threading; using System.Threading.Tasks; -using AutoGen.Core.Middleware; using AutoGen.OpenAI.Extension; using AutoGen.OpenAI.Middleware; using Azure.AI.OpenAI; diff --git a/dotnet/src/AutoGen.OpenAI/GlobalUsing.cs b/dotnet/src/AutoGen.OpenAI/GlobalUsing.cs new file mode 100644 index 00000000000..d00ae3ce4fc --- /dev/null +++ b/dotnet/src/AutoGen.OpenAI/GlobalUsing.cs @@ -0,0 +1,4 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// globalUsing.cs + +global using AutoGen.Core; diff --git a/dotnet/src/AutoGen/OpenAI/Middleware/OpenAIMessageConnector.cs b/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIMessageConnector.cs similarity index 99% rename from dotnet/src/AutoGen/OpenAI/Middleware/OpenAIMessageConnector.cs rename to dotnet/src/AutoGen.OpenAI/Middleware/OpenAIMessageConnector.cs index 887dd7fc6ea..b94bec86eb8 100644 --- a/dotnet/src/AutoGen/OpenAI/Middleware/OpenAIMessageConnector.cs +++ b/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIMessageConnector.cs @@ -6,7 +6,6 @@ using System.Linq; using System.Threading; using System.Threading.Tasks; -using AutoGen.Core.Middleware; using Azure.AI.OpenAI; namespace AutoGen.OpenAI.Middleware; diff --git a/dotnet/src/AutoGen/OpenAI/OpenAIClientAgent.cs b/dotnet/src/AutoGen.OpenAI/OpenAIClientAgent.cs similarity index 100% rename from dotnet/src/AutoGen/OpenAI/OpenAIClientAgent.cs rename to dotnet/src/AutoGen.OpenAI/OpenAIClientAgent.cs diff --git a/dotnet/src/AutoGen/OpenAI/OpenAIConfig.cs b/dotnet/src/AutoGen.OpenAI/OpenAIConfig.cs similarity index 100% rename from dotnet/src/AutoGen/OpenAI/OpenAIConfig.cs rename to dotnet/src/AutoGen.OpenAI/OpenAIConfig.cs diff --git a/dotnet/src/AutoGen.SemanticKernel/AutoGen.SemanticKernel.csproj b/dotnet/src/AutoGen.SemanticKernel/AutoGen.SemanticKernel.csproj index 36e9325663a..70d75006701 100644 --- a/dotnet/src/AutoGen.SemanticKernel/AutoGen.SemanticKernel.csproj +++ b/dotnet/src/AutoGen.SemanticKernel/AutoGen.SemanticKernel.csproj 
@@ -16,12 +16,12 @@ - + + - - + diff --git a/dotnet/src/AutoGen.SemanticKernel/GlobalUsing.cs b/dotnet/src/AutoGen.SemanticKernel/GlobalUsing.cs new file mode 100644 index 00000000000..d00ae3ce4fc --- /dev/null +++ b/dotnet/src/AutoGen.SemanticKernel/GlobalUsing.cs @@ -0,0 +1,4 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// globalUsing.cs + +global using AutoGen.Core; diff --git a/dotnet/src/AutoGen.SourceGenerator/FunctionCallGenerator.cs b/dotnet/src/AutoGen.SourceGenerator/FunctionCallGenerator.cs index 41f8f525866..50bdc03f0af 100644 --- a/dotnet/src/AutoGen.SourceGenerator/FunctionCallGenerator.cs +++ b/dotnet/src/AutoGen.SourceGenerator/FunctionCallGenerator.cs @@ -17,7 +17,7 @@ namespace AutoGen.SourceGenerator [Generator] public partial class FunctionCallGenerator : IIncrementalGenerator { - private const string FUNCTION_CALL_ATTRIBUTION = "AutoGen.FunctionAttribute"; + private const string FUNCTION_CALL_ATTRIBUTION = "AutoGen.Core.FunctionAttribute"; public void Initialize(IncrementalGeneratorInitializationContext context) { diff --git a/dotnet/src/AutoGen/Core/API/LLMConfigAPI.cs b/dotnet/src/AutoGen/API/LLMConfigAPI.cs similarity index 100% rename from dotnet/src/AutoGen/Core/API/LLMConfigAPI.cs rename to dotnet/src/AutoGen/API/LLMConfigAPI.cs diff --git a/dotnet/src/AutoGen/Core/Agent/AssistantAgent.cs b/dotnet/src/AutoGen/Agent/AssistantAgent.cs similarity index 100% rename from dotnet/src/AutoGen/Core/Agent/AssistantAgent.cs rename to dotnet/src/AutoGen/Agent/AssistantAgent.cs diff --git a/dotnet/src/AutoGen/Core/Agent/ConversableAgent.cs b/dotnet/src/AutoGen/Agent/ConversableAgent.cs similarity index 99% rename from dotnet/src/AutoGen/Core/Agent/ConversableAgent.cs rename to dotnet/src/AutoGen/Agent/ConversableAgent.cs index ce1ef79993a..e70a74a801c 100644 --- a/dotnet/src/AutoGen/Core/Agent/ConversableAgent.cs +++ b/dotnet/src/AutoGen/Agent/ConversableAgent.cs @@ -6,7 +6,6 @@ using System.Linq; using System.Threading; using 
System.Threading.Tasks; -using AutoGen.Core.Middleware; using AutoGen.OpenAI; namespace AutoGen; diff --git a/dotnet/src/AutoGen/Core/Agent/UserProxyAgent.cs b/dotnet/src/AutoGen/Agent/UserProxyAgent.cs similarity index 100% rename from dotnet/src/AutoGen/Core/Agent/UserProxyAgent.cs rename to dotnet/src/AutoGen/Agent/UserProxyAgent.cs diff --git a/dotnet/src/AutoGen/AutoGen.csproj b/dotnet/src/AutoGen/AutoGen.csproj index fcf9cf0f829..cc403935d78 100644 --- a/dotnet/src/AutoGen/AutoGen.csproj +++ b/dotnet/src/AutoGen/AutoGen.csproj @@ -1,5 +1,4 @@  - netstandard2.0 AutoGen @@ -11,14 +10,21 @@ AutoGen - cutting-edge LLM multi-agent framework + The all-in-one package for AutoGen. This package provides contracts, core functionalities, OpenAI integration, source generator, etc. for AutoGen. - - + + + + + + + + + diff --git a/dotnet/src/AutoGen/Core/ConversableAgentConfig.cs b/dotnet/src/AutoGen/ConversableAgentConfig.cs similarity index 100% rename from dotnet/src/AutoGen/Core/ConversableAgentConfig.cs rename to dotnet/src/AutoGen/ConversableAgentConfig.cs diff --git a/dotnet/src/AutoGen/GlobalUsing.cs b/dotnet/src/AutoGen/GlobalUsing.cs new file mode 100644 index 00000000000..d00ae3ce4fc --- /dev/null +++ b/dotnet/src/AutoGen/GlobalUsing.cs @@ -0,0 +1,4 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// globalUsing.cs + +global using AutoGen.Core; diff --git a/dotnet/src/AutoGen/Core/Middleware/HumanInputMiddleware.cs b/dotnet/src/AutoGen/Middleware/HumanInputMiddleware.cs similarity index 98% rename from dotnet/src/AutoGen/Core/Middleware/HumanInputMiddleware.cs rename to dotnet/src/AutoGen/Middleware/HumanInputMiddleware.cs index 7bfd29c3560..1a742b11c79 100644 --- a/dotnet/src/AutoGen/Core/Middleware/HumanInputMiddleware.cs +++ b/dotnet/src/AutoGen/Middleware/HumanInputMiddleware.cs @@ -7,8 +7,7 @@ using System.Threading; using System.Threading.Tasks; -namespace AutoGen.Core.Middleware; - +namespace AutoGen; /// /// the middleware to get human input diff --git a/dotnet/test/AutoGen.SourceGenerator.Tests/GlobalUsing.cs b/dotnet/test/AutoGen.SourceGenerator.Tests/GlobalUsing.cs new file mode 100644 index 00000000000..d00ae3ce4fc --- /dev/null +++ b/dotnet/test/AutoGen.SourceGenerator.Tests/GlobalUsing.cs @@ -0,0 +1,4 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// globalUsing.cs + +global using AutoGen.Core; diff --git a/dotnet/test/AutoGen.SourceGenerator.Tests/TopLevelStatementFunctionExample.cs b/dotnet/test/AutoGen.SourceGenerator.Tests/TopLevelStatementFunctionExample.cs index bc9e625e7cd..bbe3121509e 100644 --- a/dotnet/test/AutoGen.SourceGenerator.Tests/TopLevelStatementFunctionExample.cs +++ b/dotnet/test/AutoGen.SourceGenerator.Tests/TopLevelStatementFunctionExample.cs @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // TopLevelStatementFunctionExample.cs -using AutoGen; - public partial class TopLevelStatementFunctionExample { [Function] diff --git a/dotnet/test/AutoGen.Tests/GlobalUsing.cs b/dotnet/test/AutoGen.Tests/GlobalUsing.cs new file mode 100644 index 00000000000..d00ae3ce4fc --- /dev/null +++ b/dotnet/test/AutoGen.Tests/GlobalUsing.cs @@ -0,0 +1,4 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// globalUsing.cs + +global using AutoGen.Core; diff --git a/dotnet/test/AutoGen.Tests/MiddlewareTest.cs b/dotnet/test/AutoGen.Tests/MiddlewareTest.cs index dea20e8ccb4..6c1c89a33c1 100644 --- a/dotnet/test/AutoGen.Tests/MiddlewareTest.cs +++ b/dotnet/test/AutoGen.Tests/MiddlewareTest.cs @@ -6,7 +6,6 @@ using System.Linq; using System.Text.Json; using System.Threading.Tasks; -using AutoGen.Core.Middleware; using Azure.AI.OpenAI; using FluentAssertions; using Xunit; From 5ed491d6abe6902cd3039bd60b57eed70a22ac32 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Fri, 1 Mar 2024 12:38:34 -0800 Subject: [PATCH 17/27] update installation --- dotnet/website/articles/Installation.md | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/dotnet/website/articles/Installation.md b/dotnet/website/articles/Installation.md index 98f04946151..1aa0f1b4ceb 100644 --- a/dotnet/website/articles/Installation.md +++ b/dotnet/website/articles/Installation.md @@ -46,3 +46,20 @@ Once you finishing adding AutoGen feed, you can consume AutoGen packages in your ``` + +### Package overview +AutoGen.Net provides the following packages, you can choose to install one or more of them based on your needs: + +- `AutoGen`: The all-in-one package, which includes all the core features of AutoGen like `AssistantAgent` and `AutoGen.SourceGenerator`, plus integration with popular platforms like OpenAI, Semantic Kernel and LM Studio. +- `AutoGen.Core`: The core package; it provides the abstraction for message type, agent and group chat. +- `AutoGen.OpenAI`: This package provides the integration agents over openai models. +- `AutoGen.LMStudio`: This package provides the integration agents from LM Studio. +- `AutoGen.SemanticKernel`: This package provides the integration agents over semantic kernel. +- `AutoGen.SourceGenerator`: This package carries a source generator that adds support for type-safe function definition generation. 
+- `AutoGen.DotnetInteractive`: This package carries dotnet interactive support to execute dotnet code snippets. + +#### Help me choose +- If you just want to install one package and enjoy the core features of AutoGen, choose `AutoGen`. +- If you want to leverage AutoGen's abstraction only and want to avoid introducing any other dependencies, like `Azure.AI.OpenAI` or `Semantic Kernel`, choose `AutoGen.Core`. +- If you want to use AutoGen with openai, choose `AutoGen.OpenAI`. Similarly, choose `AutoGen.LMStudio` or `AutoGen.SemanticKernel` if you want to use agents from LM Studio or semantic kernel. +- If you just want the type-safe source generation for function calls and don't want any other features, not even AutoGen's abstraction, choose `AutoGen.SourceGenerator`. From 273b283c9d23f96b035b160c9b396f66d0e03e13 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Fri, 1 Mar 2024 15:23:26 -0800 Subject: [PATCH 18/27] refactor streamingAgent by adding StreamingMessage type --- .../src/AutoGen.Core/Agent/IStreamingAgent.cs | 2 +- .../Agent/MiddlewareStreamingAgent.cs | 6 +- .../Extension/MiddlewareExtension.cs | 4 +- dotnet/src/AutoGen.Core/Message/IMessage.cs | 12 +- .../AutoGen.Core/Message/MessageEnvelope.cs | 2 +- .../src/AutoGen.Core/Message/TextMessage.cs | 40 ++- .../AutoGen.Core/Message/ToolCallMessage.cs | 42 +++ .../Middleware/DelegateStreamingMiddleware.cs | 4 +- .../Middleware/FunctionCallMiddleware.cs | 171 +++++++++--- .../Middleware/IStreamingMiddleware.cs | 2 +- dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs | 115 +++++++++ .../{ => Agent}/OpenAIClientAgent.cs | 24 +- .../Extension/MessageExtension.cs | 124 --------- dotnet/src/AutoGen.OpenAI/GPTAgent.cs | 243 ------------------ .../Middleware/OpenAIMessageConnector.cs | 67 ++++- .../SemanticKernelAgent.cs | 2 +- dotnet/test/AutoGen.Tests/SingleAgentTest.cs | 34 ++- 17 files changed, 455 insertions(+), 439 deletions(-) create mode 100644 dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs rename 
dotnet/src/AutoGen.OpenAI/{ => Agent}/OpenAIClientAgent.cs (78%) delete mode 100644 dotnet/src/AutoGen.OpenAI/GPTAgent.cs diff --git a/dotnet/src/AutoGen.Core/Agent/IStreamingAgent.cs b/dotnet/src/AutoGen.Core/Agent/IStreamingAgent.cs index 3fa121a7b08..f4004b1397b 100644 --- a/dotnet/src/AutoGen.Core/Agent/IStreamingAgent.cs +++ b/dotnet/src/AutoGen.Core/Agent/IStreamingAgent.cs @@ -12,7 +12,7 @@ namespace AutoGen.Core; /// public interface IStreamingAgent : IAgent { - public Task> GenerateStreamingReplyAsync( + public Task> GenerateStreamingReplyAsync( IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default); diff --git a/dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs b/dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs index 3aaba4da61c..5470e9e13ae 100644 --- a/dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs +++ b/dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs @@ -40,7 +40,7 @@ public async Task GenerateReplyAsync(IEnumerable messages, G throw new NotImplementedException("Streaming agent does not support non-streaming reply."); } - public Task> GenerateStreamingReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) + public Task> GenerateStreamingReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) { var agent = _agent; foreach (var middleware in _middlewares) @@ -56,7 +56,7 @@ public void Use(IStreamingMiddleware middleware) _middlewares.Add(middleware); } - public void Use(Func>> func, string? middlewareName = null) + public void Use(Func>> func, string? 
middlewareName = null) { _middlewares.Add(new DelegateStreamingMiddleware(middlewareName, new DelegateStreamingMiddleware.MiddlewareDelegate(func))); } @@ -79,7 +79,7 @@ public async Task GenerateReplyAsync(IEnumerable messages, G throw new NotImplementedException("Streaming agent does not support non-streaming reply."); } - public Task> GenerateStreamingReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) + public Task> GenerateStreamingReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) { // TODO // fix this diff --git a/dotnet/src/AutoGen.Core/Extension/MiddlewareExtension.cs b/dotnet/src/AutoGen.Core/Extension/MiddlewareExtension.cs index 6008b927605..5f951122a08 100644 --- a/dotnet/src/AutoGen.Core/Extension/MiddlewareExtension.cs +++ b/dotnet/src/AutoGen.Core/Extension/MiddlewareExtension.cs @@ -190,7 +190,7 @@ public static MiddlewareAgent RegisterPrintFormatMessageHook(thi /// public static MiddlewareStreamingAgent RegisterStreamingMiddleware( this TAgent agent, - Func>> func, + Func>> func, string? middlewareName = null) where TAgent : IStreamingAgent { @@ -205,7 +205,7 @@ public static MiddlewareAgent RegisterPrintFormatMessageHook(thi /// public static MiddlewareStreamingAgent RegisterStreamingMiddleware( this MiddlewareStreamingAgent agent, - Func>> func, + Func>> func, string? middlewareName = null) where TAgent : IStreamingAgent { diff --git a/dotnet/src/AutoGen.Core/Message/IMessage.cs b/dotnet/src/AutoGen.Core/Message/IMessage.cs index 24d8e383875..7b48f4f0d63 100644 --- a/dotnet/src/AutoGen.Core/Message/IMessage.cs +++ b/dotnet/src/AutoGen.Core/Message/IMessage.cs @@ -33,12 +33,20 @@ namespace AutoGen.Core; /// /// /// -public interface IMessage +public interface IMessage : IStreamingMessage +{ +} + +public interface IMessage : IMessage, IStreamingMessage +{ +} + +public interface IStreamingMessage { string? 
From { get; set; } } -public interface IMessage : IMessage +public interface IStreamingMessage : IStreamingMessage { T Content { get; } } diff --git a/dotnet/src/AutoGen.Core/Message/MessageEnvelope.cs b/dotnet/src/AutoGen.Core/Message/MessageEnvelope.cs index a6bad770311..2646174b1ff 100644 --- a/dotnet/src/AutoGen.Core/Message/MessageEnvelope.cs +++ b/dotnet/src/AutoGen.Core/Message/MessageEnvelope.cs @@ -5,7 +5,7 @@ namespace AutoGen.Core; -public class MessageEnvelope : IMessage +public class MessageEnvelope : IMessage, IStreamingMessage { public MessageEnvelope(T content, string? from = null, IDictionary? metadata = null) { diff --git a/dotnet/src/AutoGen.Core/Message/TextMessage.cs b/dotnet/src/AutoGen.Core/Message/TextMessage.cs index 9c9f8c1c828..b59ddfb9a57 100644 --- a/dotnet/src/AutoGen.Core/Message/TextMessage.cs +++ b/dotnet/src/AutoGen.Core/Message/TextMessage.cs @@ -3,7 +3,7 @@ namespace AutoGen.Core; -public class TextMessage : IMessage +public class TextMessage : IMessage, IStreamingMessage { public TextMessage(Role role, string content, string? from = null) { @@ -12,6 +12,28 @@ public TextMessage(Role role, string content, string? from = null) this.From = from; } + public TextMessage(TextMessageUpdate update) + { + this.Content = update.Content; + this.Role = update.Role; + this.From = update.From; + } + + public void Update(TextMessageUpdate update) + { + if (update.Role != this.Role) + { + throw new System.ArgumentException("Role mismatch", nameof(update)); + } + + if (update.From != this.From) + { + throw new System.ArgumentException("From mismatch", nameof(update)); + } + + this.Content = this.Content + update.Content; + } + public Role Role { get; set; } public string Content { get; set; } @@ -23,3 +45,19 @@ public override string ToString() return $"TextMessage({this.Role}, {this.Content}, {this.From})"; } } + +public class TextMessageUpdate : IStreamingMessage +{ + public TextMessageUpdate(Role role, string content, string? 
from = null) + { + this.Content = content; + this.From = from; + this.Role = role; + } + + public string Content { get; set; } + + public string? From { get; set; } + + public Role Role { get; set; } +} diff --git a/dotnet/src/AutoGen.Core/Message/ToolCallMessage.cs b/dotnet/src/AutoGen.Core/Message/ToolCallMessage.cs index 6a6123037ed..8dcd98ea0ec 100644 --- a/dotnet/src/AutoGen.Core/Message/ToolCallMessage.cs +++ b/dotnet/src/AutoGen.Core/Message/ToolCallMessage.cs @@ -48,6 +48,32 @@ public ToolCallMessage(string functionName, string functionArgs, string? from = this.ToolCalls = new List { new ToolCall(functionName, functionArgs) }; } + public ToolCallMessage(ToolCallMessageUpdate update) + { + this.From = update.From; + this.ToolCalls = new List { new ToolCall(update.FunctionName, update.FunctionArgumentUpdate) }; + } + + public void Update(ToolCallMessageUpdate update) + { + // firstly, valid if the update is from the same agent + if (update.From != this.From) + { + throw new System.ArgumentException("From mismatch", nameof(update)); + } + + // if update.FunctionName exists in the tool calls, update the function arguments + var toolCall = this.ToolCalls.FirstOrDefault(tc => tc.FunctionName == update.FunctionName); + if (toolCall is not null) + { + toolCall.FunctionArguments += update.FunctionArgumentUpdate; + } + else + { + this.ToolCalls.Add(new ToolCall(update.FunctionName, update.FunctionArgumentUpdate)); + } + } + public IList ToolCalls { get; set; } public string? From { get; set; } @@ -64,3 +90,19 @@ public override string ToString() return sb.ToString(); } } + +public class ToolCallMessageUpdate : IStreamingMessage +{ + public ToolCallMessageUpdate(string functionName, string functionArgumentUpdate, string? from = null) + { + this.From = from; + this.FunctionName = functionName; + this.FunctionArgumentUpdate = functionArgumentUpdate; + } + + public string? 
From { get; set; } + + public string FunctionName { get; set; } + + public string FunctionArgumentUpdate { get; set; } +} diff --git a/dotnet/src/AutoGen.Core/Middleware/DelegateStreamingMiddleware.cs b/dotnet/src/AutoGen.Core/Middleware/DelegateStreamingMiddleware.cs index a366c954193..5499abccf4c 100644 --- a/dotnet/src/AutoGen.Core/Middleware/DelegateStreamingMiddleware.cs +++ b/dotnet/src/AutoGen.Core/Middleware/DelegateStreamingMiddleware.cs @@ -9,7 +9,7 @@ namespace AutoGen.Core; internal class DelegateStreamingMiddleware : IStreamingMiddleware { - public delegate Task> MiddlewareDelegate( + public delegate Task> MiddlewareDelegate( MiddlewareContext context, IStreamingAgent agent, CancellationToken cancellationToken); @@ -24,7 +24,7 @@ public DelegateStreamingMiddleware(string? name, MiddlewareDelegate middlewareDe public string? Name { get; } - public Task> InvokeAsync( + public Task> InvokeAsync( MiddlewareContext context, IStreamingAgent agent, CancellationToken cancellationToken = default) diff --git a/dotnet/src/AutoGen.Core/Middleware/FunctionCallMiddleware.cs b/dotnet/src/AutoGen.Core/Middleware/FunctionCallMiddleware.cs index ecff6cf401f..c8a68de5147 100644 --- a/dotnet/src/AutoGen.Core/Middleware/FunctionCallMiddleware.cs +++ b/dotnet/src/AutoGen.Core/Middleware/FunctionCallMiddleware.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; @@ -22,8 +23,13 @@ namespace AutoGen.Core; /// /// If the reply from the inner agent is but the tool calls is not available in this middleware's function map, /// or the reply from the inner agent is not , the original reply from the inner agent will be returned. 
+/// +/// When used as a streaming middleware, if the streaming reply from the inner agent is or , +/// This middleware will update the message accordingly and invoke the function if the tool call is available in this middleware's function map. +/// If the streaming reply from the inner agent is other types of message, the most recent message will be used to invoke the function. +/// /// -public class FunctionCallMiddleware : IMiddleware +public class FunctionCallMiddleware : IMiddleware, IStreamingMiddleware { private readonly IEnumerable? functions; private readonly IDictionary>>? functionMap; @@ -42,34 +48,10 @@ public class FunctionCallMiddleware : IMiddleware public async Task InvokeAsync(MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default) { - // if the last message is a function call message, invoke the function and return the result instead of sending to the agent. var lastMessage = context.Messages.Last(); if (lastMessage is ToolCallMessage toolCallMessage) { - var toolCallResult = new List(); - var toolCalls = toolCallMessage.ToolCalls; - foreach (var toolCall in toolCalls) - { - var functionName = toolCall.FunctionName; - var functionArguments = toolCall.FunctionArguments; - if (this.functionMap?.TryGetValue(functionName, out var func) is true) - { - var result = await func(functionArguments); - toolCallResult.Add(new ToolCall(functionName, functionArguments, result)); - } - else if (this.functionMap is not null) - { - var errorMessage = $"Function {functionName} is not available. 
Available functions are: {string.Join(", ", this.functionMap.Select(f => f.Key))}"; - - toolCallResult.Add(new ToolCall(functionName, functionArguments, errorMessage)); - } - else - { - throw new InvalidOperationException("FunctionMap is not available"); - } - } - - return new ToolCallResultMessage(toolCallResult, from: agent.Name); + return await this.InvokeToolCallMessagesBeforeInvokingAgentAsync(toolCallMessage, agent); } // combine functions @@ -82,31 +64,140 @@ public async Task InvokeAsync(MiddlewareContext context, IAgent agent, // if the reply is a function call message plus the function's name is available in function map, invoke the function and return the result instead of sending to the agent. if (reply is ToolCallMessage toolCallMsg) { - var toolCallsReply = toolCallMsg.ToolCalls; - var toolCallResult = new List(); - foreach (var toolCall in toolCallsReply) + return await this.InvokeToolCallMessagesAfterInvokingAgentAsync(toolCallMsg, agent); + } + + // for all other messages, just return the reply from the agent. + return reply; + } + + public Task> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, CancellationToken cancellationToken = default) + { + return Task.FromResult(this.StreamingInvokeAsync(context, agent, cancellationToken)); + } + + private async IAsyncEnumerable StreamingInvokeAsync( + MiddlewareContext context, + IStreamingAgent agent, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + var lastMessage = context.Messages.Last(); + if (lastMessage is ToolCallMessage toolCallMessage) + { + yield return await this.InvokeToolCallMessagesBeforeInvokingAgentAsync(toolCallMessage, agent); + } + + // combine functions + var options = new GenerateReplyOptions(context.Options ?? new GenerateReplyOptions()); + var combinedFunctions = this.functions?.Concat(options.Functions ?? []) ?? options.Functions; + options.Functions = combinedFunctions?.ToArray(); + + IStreamingMessage? 
initMessage = default; + await foreach (var message in await agent.GenerateStreamingReplyAsync(context.Messages, options, cancellationToken)) + { + if (message is ToolCallMessageUpdate toolCallMessageUpdate) { - var fName = toolCall.FunctionName; - var fArgs = toolCall.FunctionArguments; - if (this.functionMap?.TryGetValue(fName, out var func) is true) + if (initMessage is null) + { + initMessage = new ToolCallMessage(toolCallMessageUpdate); + } + else if (initMessage is ToolCallMessage toolCall) { - var result = await func(fArgs); - toolCallResult.Add(new ToolCall(fName, fArgs, result)); + toolCall.Update(toolCallMessageUpdate); + } + else + { + throw new InvalidOperationException("The first message is ToolCallMessage, but the update message is not ToolCallMessageUpdate"); } } + else if (message is TextMessageUpdate textMessageUpdate) + { + if (initMessage is null) + { + initMessage = new TextMessage(textMessageUpdate); + } + else if (initMessage is TextMessage textMessage) + { + textMessage.Update(textMessageUpdate); + } + else + { + throw new InvalidOperationException("The first message is TextMessage, but the update message is not TextMessageUpdate"); + } + } + else + { + initMessage = message; + } + + yield return initMessage; + } + + if (initMessage is ToolCallMessage toolCallMsg) + { + yield return await this.InvokeToolCallMessagesAfterInvokingAgentAsync(toolCallMsg, agent); + } + else if (initMessage is not null) + { + yield return initMessage; + } + else + { + throw new InvalidOperationException("The agent returns no message."); + } + } - if (toolCallResult.Count() > 0) + private async Task InvokeToolCallMessagesBeforeInvokingAgentAsync(ToolCallMessage toolCallMessage, IAgent agent) + { + var toolCallResult = new List(); + var toolCalls = toolCallMessage.ToolCalls; + foreach (var toolCall in toolCalls) + { + var functionName = toolCall.FunctionName; + var functionArguments = toolCall.FunctionArguments; + if (this.functionMap?.TryGetValue(functionName, out 
var func) is true) { - var toolCallResultMessage = new ToolCallResultMessage(toolCallResult, from: agent.Name); - return new AggregateMessage(toolCallMsg, toolCallResultMessage, from: agent.Name); + var result = await func(functionArguments); + toolCallResult.Add(new ToolCall(functionName, functionArguments, result)); + } + else if (this.functionMap is not null) + { + var errorMessage = $"Function {functionName} is not available. Available functions are: {string.Join(", ", this.functionMap.Select(f => f.Key))}"; + + toolCallResult.Add(new ToolCall(functionName, functionArguments, errorMessage)); } else { - return reply; + throw new InvalidOperationException("FunctionMap is not available"); } } - // for all other messages, just return the reply from the agent. - return reply; + return new ToolCallResultMessage(toolCallResult, from: agent.Name); + } + + private async Task InvokeToolCallMessagesAfterInvokingAgentAsync(ToolCallMessage toolCallMsg, IAgent agent) + { + var toolCallsReply = toolCallMsg.ToolCalls; + var toolCallResult = new List(); + foreach (var toolCall in toolCallsReply) + { + var fName = toolCall.FunctionName; + var fArgs = toolCall.FunctionArguments; + if (this.functionMap?.TryGetValue(fName, out var func) is true) + { + var result = await func(fArgs); + toolCallResult.Add(new ToolCall(fName, fArgs, result)); + } + } + + if (toolCallResult.Count() > 0) + { + var toolCallResultMessage = new ToolCallResultMessage(toolCallResult, from: agent.Name); + return new AggregateMessage(toolCallMsg, toolCallResultMessage, from: agent.Name); + } + else + { + return toolCallMsg; + } } } diff --git a/dotnet/src/AutoGen.Core/Middleware/IStreamingMiddleware.cs b/dotnet/src/AutoGen.Core/Middleware/IStreamingMiddleware.cs index dc4d98cce7e..b8965dcc41c 100644 --- a/dotnet/src/AutoGen.Core/Middleware/IStreamingMiddleware.cs +++ b/dotnet/src/AutoGen.Core/Middleware/IStreamingMiddleware.cs @@ -14,7 +14,7 @@ public interface IStreamingMiddleware { public string? 
Name { get; } - public Task> InvokeAsync( + public Task> InvokeAsync( MiddlewareContext context, IStreamingAgent agent, CancellationToken cancellationToken = default); diff --git a/dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs b/dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs new file mode 100644 index 00000000000..dd2383b148c --- /dev/null +++ b/dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs @@ -0,0 +1,115 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// GPTAgent.cs + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using AutoGen.OpenAI.Middleware; +using Azure.AI.OpenAI; + +namespace AutoGen.OpenAI; + +/// +/// GPT agent that can be used to connect to OpenAI chat models like GPT-3.5, GPT-4, etc. +/// supports the following message types as input: +/// - +/// - +/// - +/// - +/// - +/// - +/// - where T is +/// - where TMessage1 is and TMessage2 is +/// +/// returns the following message types: +/// - +/// - +/// - where TMessage1 is and TMessage2 is +/// +public class GPTAgent : IStreamingAgent +{ + private readonly IDictionary>>? functionMap; + private readonly OpenAIClient openAIClient; + private readonly string? modelName; + private readonly OpenAIClientAgent _innerAgent; + + public GPTAgent( + string name, + string systemMessage, + ILLMConfig config, + float temperature = 0.7f, + int maxTokens = 1024, + IEnumerable? functions = null, + IDictionary>>? 
functionMap = null) + { + openAIClient = config switch + { + AzureOpenAIConfig azureConfig => new OpenAIClient(new Uri(azureConfig.Endpoint), new Azure.AzureKeyCredential(azureConfig.ApiKey)), + OpenAIConfig openAIConfig => new OpenAIClient(openAIConfig.ApiKey), + _ => throw new ArgumentException($"Unsupported config type {config.GetType()}"), + }; + + modelName = config switch + { + AzureOpenAIConfig azureConfig => azureConfig.DeploymentName, + OpenAIConfig openAIConfig => openAIConfig.ModelId, + _ => throw new ArgumentException($"Unsupported config type {config.GetType()}"), + }; + + _innerAgent = new OpenAIClientAgent(openAIClient, name, systemMessage, modelName, temperature, maxTokens, functions); + Name = name; + this.functionMap = functionMap; + } + + public GPTAgent( + string name, + string systemMessage, + OpenAIClient openAIClient, + string modelName, + float temperature = 0.7f, + int maxTokens = 1024, + IEnumerable? functions = null, + IDictionary>>? functionMap = null) + { + this.openAIClient = openAIClient; + this.modelName = modelName; + Name = name; + this.functionMap = functionMap; + _innerAgent = new OpenAIClientAgent(openAIClient, name, systemMessage, modelName, temperature, maxTokens, functions); + } + + public string Name { get; } + + public async Task GenerateReplyAsync( + IEnumerable messages, + GenerateReplyOptions? options = null, + CancellationToken cancellationToken = default) + { + var oaiConnectorMiddleware = new OpenAIMessageConnector(); + var agent = this._innerAgent.RegisterMiddleware(oaiConnectorMiddleware); + if (this.functionMap is not null) + { + var functionMapMiddleware = new FunctionCallMiddleware(functionMap: this.functionMap); + agent = agent.RegisterMiddleware(functionMapMiddleware); + } + + return await agent.GenerateReplyAsync(messages, options, cancellationToken); + } + + public async Task> GenerateStreamingReplyAsync( + IEnumerable messages, + GenerateReplyOptions? 
options = null, + CancellationToken cancellationToken = default) + { + var oaiConnectorMiddleware = new OpenAIMessageConnector(); + var agent = this._innerAgent.RegisterStreamingMiddleware(oaiConnectorMiddleware); + if (this.functionMap is not null) + { + var functionMapMiddleware = new FunctionCallMiddleware(functionMap: this.functionMap); + agent = agent.RegisterStreamingMiddleware(functionMapMiddleware); + } + + return await agent.GenerateStreamingReplyAsync(messages, options, cancellationToken); + } +} diff --git a/dotnet/src/AutoGen.OpenAI/OpenAIClientAgent.cs b/dotnet/src/AutoGen.OpenAI/Agent/OpenAIClientAgent.cs similarity index 78% rename from dotnet/src/AutoGen.OpenAI/OpenAIClientAgent.cs rename to dotnet/src/AutoGen.OpenAI/Agent/OpenAIClientAgent.cs index 94148243642..ac5c8691fc8 100644 --- a/dotnet/src/AutoGen.OpenAI/OpenAIClientAgent.cs +++ b/dotnet/src/AutoGen.OpenAI/Agent/OpenAIClientAgent.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; using AutoGen.OpenAI.Extension; @@ -24,6 +25,7 @@ namespace AutoGen.OpenAI; /// /// /// where T is : chat response message. +/// where T is : streaming chat completions update. /// /// /// @@ -67,12 +69,30 @@ public class OpenAIClientAgent : IStreamingAgent return new MessageEnvelope(reply.Value.Choices.First().Message, from: this.Name); } - public Task> GenerateStreamingReplyAsync( + public Task> GenerateStreamingReplyAsync( IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) { - throw new NotImplementedException(); + return Task.FromResult(this.StreamingReplyAsync(messages, options, cancellationToken)); + } + + private async IAsyncEnumerable StreamingReplyAsync( + IEnumerable messages, + GenerateReplyOptions? 
options = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + var settings = this.CreateChatCompletionsOptions(options, messages); + var response = await this.openAIClient.GetChatCompletionsStreamingAsync(settings); + await foreach (var update in response.WithCancellation(cancellationToken)) + { + if (update.ChoiceIndex > 0) + { + throw new InvalidOperationException("Only one choice is supported in streaming response"); + } + + yield return new MessageEnvelope(update, from: this.Name); + } } private ChatCompletionsOptions CreateChatCompletionsOptions(GenerateReplyOptions? options, IEnumerable messages) diff --git a/dotnet/src/AutoGen.OpenAI/Extension/MessageExtension.cs b/dotnet/src/AutoGen.OpenAI/Extension/MessageExtension.cs index 45171091168..92e0f3776f5 100644 --- a/dotnet/src/AutoGen.OpenAI/Extension/MessageExtension.cs +++ b/dotnet/src/AutoGen.OpenAI/Extension/MessageExtension.cs @@ -12,65 +12,6 @@ public static class MessageExtension { public static string TEXT_CONTENT_TYPE = "text"; public static string IMAGE_CONTENT_TYPE = "image"; - - public static Message ToMessage(this ChatRequestMessage message) - { - if (message is ChatRequestUserMessage userMessage) - { - var msg = new Message(Role.User, userMessage.Content) - { - Value = message, - }; - - if (userMessage.MultimodalContentItems != null) - { - foreach (var item in userMessage.MultimodalContentItems) - { - if (item is ChatMessageTextContentItem textItem) - { - msg.Metadata.Add(new KeyValuePair(TEXT_CONTENT_TYPE, textItem.Text)); - } - else if (item is ChatMessageImageContentItem imageItem) - { - msg.Metadata.Add(new KeyValuePair(IMAGE_CONTENT_TYPE, imageItem.ImageUrl.Url.OriginalString)); - } - } - } - - return msg; - } - else if (message is ChatRequestAssistantMessage assistantMessage) - { - return new Message(Role.Assistant, assistantMessage.Content) - { - Value = message, - FunctionArguments = assistantMessage.FunctionCall?.Arguments, - FunctionName = 
assistantMessage.FunctionCall?.Name, - From = assistantMessage.Name, - }; - } - else if (message is ChatRequestSystemMessage systemMessage) - { - return new Message(Role.System, systemMessage.Content) - { - Value = message, - From = systemMessage.Name, - }; - } - else if (message is ChatRequestFunctionMessage functionMessage) - { - return new Message(Role.Function, functionMessage.Content) - { - Value = message, - FunctionName = functionMessage.Name, - }; - } - else - { - throw new ArgumentException($"Unknown message type: {message.GetType()}"); - } - } - public static ChatRequestUserMessage ToChatRequestUserMessage(this Message message) { if (message.Value is ChatRequestUserMessage message1) @@ -109,61 +50,6 @@ public static ChatRequestUserMessage ToChatRequestUserMessage(this Message messa throw new ArgumentException("Content is null and metadata is null"); } - public static ChatRequestAssistantMessage ToChatRequestAssistantMessage(this Message message) - { - if (message.Value is ChatRequestAssistantMessage message1) - { - return message1; - } - - var assistantMessage = new ChatRequestAssistantMessage(message.Content ?? string.Empty); - if (message.FunctionName != null && message.FunctionArguments != null) - { - assistantMessage.FunctionCall = new FunctionCall(message.FunctionName, message.FunctionArguments ?? 
string.Empty); - } - - return assistantMessage; - } - - public static ChatRequestSystemMessage ToChatRequestSystemMessage(this Message message) - { - if (message.Value is ChatRequestSystemMessage message1) - { - return message1; - } - - if (message.Content is null) - { - throw new ArgumentException("Content is null"); - } - - var systemMessage = new ChatRequestSystemMessage(message.Content); - - return systemMessage; - } - - public static ChatRequestFunctionMessage ToChatRequestFunctionMessage(this Message message) - { - if (message.Value is ChatRequestFunctionMessage message1) - { - return message1; - } - - if (message.FunctionName is null) - { - throw new ArgumentException("FunctionName is null"); - } - - if (message.Content is null) - { - throw new ArgumentException("Content is null"); - } - - var functionMessage = new ChatRequestFunctionMessage(message.FunctionName, message.Content); - - return functionMessage; - } - public static IEnumerable ToOpenAIChatRequestMessage(this IAgent agent, IMessage message) { if (message is IMessage oaiMessage) @@ -339,14 +225,4 @@ public static IEnumerable ToOpenAIChatRequestMessage(this IA } } } - - public static IEnumerable ToAutoGenMessages(this IAgent agent, IEnumerable> openaiMessages) - { - throw new NotImplementedException(); - } - - public static IMessage ToAutoGenMessage(ChatRequestMessage openaiMessage, string? from = null) - { - throw new NotImplementedException(); - } } diff --git a/dotnet/src/AutoGen.OpenAI/GPTAgent.cs b/dotnet/src/AutoGen.OpenAI/GPTAgent.cs deleted file mode 100644 index fcb96588d23..00000000000 --- a/dotnet/src/AutoGen.OpenAI/GPTAgent.cs +++ /dev/null @@ -1,243 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. 
-// GPTAgent.cs - -using System; -using System.Collections.Generic; -using System.Linq; -using System.Threading; -using System.Threading.Tasks; -using AutoGen.OpenAI.Extension; -using AutoGen.OpenAI.Middleware; -using Azure.AI.OpenAI; - -namespace AutoGen.OpenAI; - -/// -/// GPT agent that can be used to connect to OpenAI chat models like GPT-3.5, GPT-4, etc. -/// supports the following message types as input: -/// - -/// - -/// - -/// - -/// - -/// - -/// - where T is -/// - where TMessage1 is and TMessage2 is -/// -/// returns the following message types: -/// - -/// - -/// - where TMessage1 is and TMessage2 is -/// -public class GPTAgent : IStreamingAgent -{ - private readonly string _systemMessage; - private readonly IEnumerable? _functions; - private readonly float _temperature; - private readonly int _maxTokens = 1024; - private readonly IDictionary>>? functionMap; - private readonly OpenAIClient openAIClient; - private readonly string? modelName; - private readonly OpenAIClientAgent _innerAgent; - - public GPTAgent( - string name, - string systemMessage, - ILLMConfig config, - float temperature = 0.7f, - int maxTokens = 1024, - IEnumerable? functions = null, - IDictionary>>? 
functionMap = null) - { - openAIClient = config switch - { - AzureOpenAIConfig azureConfig => new OpenAIClient(new Uri(azureConfig.Endpoint), new Azure.AzureKeyCredential(azureConfig.ApiKey)), - OpenAIConfig openAIConfig => new OpenAIClient(openAIConfig.ApiKey), - _ => throw new ArgumentException($"Unsupported config type {config.GetType()}"), - }; - - modelName = config switch - { - AzureOpenAIConfig azureConfig => azureConfig.DeploymentName, - OpenAIConfig openAIConfig => openAIConfig.ModelId, - _ => throw new ArgumentException($"Unsupported config type {config.GetType()}"), - }; - - _innerAgent = new OpenAIClientAgent(openAIClient, name, systemMessage, modelName, temperature, maxTokens, functions); - _systemMessage = systemMessage; - _functions = functions; - Name = name; - _temperature = temperature; - _maxTokens = maxTokens; - this.functionMap = functionMap; - } - - public GPTAgent( - string name, - string systemMessage, - OpenAIClient openAIClient, - string modelName, - float temperature = 0.7f, - int maxTokens = 1024, - IEnumerable? functions = null, - IDictionary>>? functionMap = null) - { - this.openAIClient = openAIClient; - this.modelName = modelName; - _systemMessage = systemMessage; - _functions = functions; - Name = name; - _temperature = temperature; - _maxTokens = maxTokens; - this.functionMap = functionMap; - _innerAgent = new OpenAIClientAgent(openAIClient, name, systemMessage, modelName, temperature, maxTokens, functions); - } - - public string Name { get; } - - public async Task GenerateReplyAsync( - IEnumerable messages, - GenerateReplyOptions? 
options = null, - CancellationToken cancellationToken = default) - { - var oaiConnectorMiddleware = new OpenAIMessageConnector(); - var agent = this._innerAgent.RegisterMiddleware(oaiConnectorMiddleware); - if (this.functionMap is not null) - { - var functionMapMiddleware = new FunctionCallMiddleware(functionMap: this.functionMap); - agent = agent.RegisterMiddleware(functionMapMiddleware); - } - - return await agent.GenerateReplyAsync(messages, options, cancellationToken); - } - - public async Task> GenerateStreamingReplyAsync( - IEnumerable messages, - GenerateReplyOptions? options = null, - CancellationToken cancellationToken = default) - { - var settings = this.CreateChatCompletionsOptions(options, messages); - var response = await this.openAIClient.GetChatCompletionsStreamingAsync(settings, cancellationToken); - return this.ProcessResponse(response); - } - - private async IAsyncEnumerable ProcessResponse(StreamingResponse response) - { - var content = string.Empty; - string? functionName = default; - string? 
functionArguments = default; - await foreach (var chunk in response) - { - if (chunk?.FunctionName is not null) - { - functionName = chunk.FunctionName; - } - - if (chunk?.FunctionArgumentsUpdate is not null) - { - if (functionArguments is null) - { - functionArguments = chunk.FunctionArgumentsUpdate; - } - else - { - functionArguments += chunk.FunctionArgumentsUpdate; - } - } - - if (chunk?.ContentUpdate is not null) - { - if (content is null) - { - content = chunk.ContentUpdate; - } - else - { - content += chunk.ContentUpdate; - } - } - - // case 1: plain text content - // in this case we yield the message - if (content is not null && functionName is null) - { - var msg = new TextMessage(Role.Assistant, content, from: this.Name); - - yield return msg; - continue; - } - - // case 2: function call - // in this case, we yield the message once after function name is available and function args has been updated - if (functionName is not null && functionArguments is not null) - { - var msg = new ToolCallMessage(functionName, functionArguments, from: this.Name); - yield return msg; - - if (functionMap is not null && chunk?.FinishReason is not null && chunk.FinishReason == CompletionsFinishReason.FunctionCall) - { - // call the function - if (this.functionMap.TryGetValue(functionName, out var func)) - { - var result = await func(functionArguments); - yield return new ToolCallResultMessage(result, functionName, functionArguments, from: this.Name); - } - else - { - var errorMessage = $"Function {functionName} is not available. Available functions are: {string.Join(", ", this.functionMap.Select(f => f.Key))}"; - yield return new ToolCallResultMessage(errorMessage, functionName, functionArguments, from: this.Name); - } - } - - continue; - } - } - } - - private ChatCompletionsOptions CreateChatCompletionsOptions(GenerateReplyOptions? 
options, IEnumerable messages) - { - var oaiMessages = this.ProcessMessages(messages); - var settings = new ChatCompletionsOptions(this.modelName, oaiMessages) - { - MaxTokens = options?.MaxToken ?? _maxTokens, - Temperature = options?.Temperature ?? _temperature, - }; - - var openAIFunctions = options?.Functions?.Select(f => f.ToOpenAIFunctionDefinition()); - var functions = openAIFunctions ?? _functions; - if (functions is not null && functions.Count() > 0) - { - foreach (var f in functions) - { - settings.Functions.Add(f); - } - //foreach (var f in functions) - //{ - // settings.Tools.Add(new ChatCompletionsFunctionToolDefinition(f)); - //} - } - - if (options?.StopSequence is var sequence && sequence is { Length: > 0 }) - { - foreach (var seq in sequence) - { - settings.StopSequences.Add(seq); - } - } - - return settings; - } - - - private IEnumerable ProcessMessages(IEnumerable messages) - { - // add system message if there's no system message in messages - var openAIMessages = messages.SelectMany(m => this.ToOpenAIChatRequestMessage(m)) ?? 
[]; - if (!openAIMessages.Any(m => m is ChatRequestSystemMessage)) - { - openAIMessages = new[] { new ChatRequestSystemMessage(_systemMessage) }.Concat(openAIMessages); - } - - return openAIMessages; - } -} diff --git a/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIMessageConnector.cs b/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIMessageConnector.cs index b94bec86eb8..ad3ac26c638 100644 --- a/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIMessageConnector.cs +++ b/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIMessageConnector.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; using Azure.AI.OpenAI; @@ -43,9 +44,46 @@ public async Task InvokeAsync(MiddlewareContext context, IAgent agent, return PostProcessMessage(reply); } - public Task> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, CancellationToken cancellationToken = default) + public async Task> InvokeAsync( + MiddlewareContext context, + IStreamingAgent agent, + CancellationToken cancellationToken = default) { - throw new NotImplementedException(); + return InvokeStreamingAsync(context, agent, cancellationToken); + } + + private async IAsyncEnumerable InvokeStreamingAsync( + MiddlewareContext context, + IStreamingAgent agent, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + var chatMessages = ProcessIncomingMessages(agent, context.Messages) + .Select(m => new MessageEnvelope(m)); + var streamingReply = await agent.GenerateStreamingReplyAsync(chatMessages, context.Options, cancellationToken); + string? 
currentToolName = null; + await foreach (var reply in streamingReply) + { + if (reply is IStreamingMessage update) + { + if (update.Content.FunctionName is string functionName) + { + currentToolName = functionName; + } + else if (update.Content.ToolCallUpdate is StreamingFunctionToolCallUpdate toolCallUpdate && toolCallUpdate.Name is string toolCallName) + { + currentToolName = toolCallName; + } + var postProcessMessage = PostProcessStreamingMessage(update, currentToolName); + if (postProcessMessage != null) + { + yield return postProcessMessage; + } + } + else + { + throw new InvalidOperationException("The type of streaming reply is not supported. Must be one of StreamingChatCompletionsUpdate"); + } + } } public IMessage PostProcessMessage(IMessage message) @@ -64,6 +102,31 @@ public IMessage PostProcessMessage(IMessage message) }; } + public IStreamingMessage? PostProcessStreamingMessage(IStreamingMessage update, string? currentToolName) + { + if (update.Content.ContentUpdate is string contentUpdate) + { + // text message + return new TextMessageUpdate(Role.Assistant, contentUpdate, from: update.From); + } + else if (update.Content.FunctionName is string functionName) + { + return new ToolCallMessageUpdate(functionName, string.Empty, from: update.From); + } + else if (update.Content.FunctionArgumentsUpdate is string functionArgumentsUpdate && currentToolName is string) + { + return new ToolCallMessageUpdate(currentToolName, functionArgumentsUpdate, from: update.From); + } + else if (update.Content.ToolCallUpdate is StreamingFunctionToolCallUpdate tooCallUpdate && currentToolName is string) + { + return new ToolCallMessageUpdate(tooCallUpdate.Name ?? 
currentToolName, tooCallUpdate.ArgumentsUpdate, from: update.From); + } + else + { + return null; + } + } + private IMessage PostProcessMessage(IMessage message) { var chatResponseMessage = message.Content; diff --git a/dotnet/src/AutoGen.SemanticKernel/SemanticKernelAgent.cs b/dotnet/src/AutoGen.SemanticKernel/SemanticKernelAgent.cs index 8a51d874bc8..684a9b4aa5c 100644 --- a/dotnet/src/AutoGen.SemanticKernel/SemanticKernelAgent.cs +++ b/dotnet/src/AutoGen.SemanticKernel/SemanticKernelAgent.cs @@ -92,7 +92,7 @@ public async Task GenerateReplyAsync(IEnumerable messages, G } } - public async Task> GenerateStreamingReplyAsync( + public async Task> GenerateStreamingReplyAsync( IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) diff --git a/dotnet/test/AutoGen.Tests/SingleAgentTest.cs b/dotnet/test/AutoGen.Tests/SingleAgentTest.cs index 92ce8079461..e1bab5738e4 100644 --- a/dotnet/test/AutoGen.Tests/SingleAgentTest.cs +++ b/dotnet/test/AutoGen.Tests/SingleAgentTest.cs @@ -235,23 +235,24 @@ private async Task EchoFunctionCallExecutionTestAsync(IAgent agent) private async Task EchoFunctionCallExecutionStreamingTestAsync(IStreamingAgent agent) { - var message = new Message(Role.System, "You are a helpful AI assistant that echo whatever user says"); - var helloWorld = new Message(Role.User, "echo Hello world"); + var message = new TextMessage(Role.System, "You are a helpful AI assistant that echo whatever user says"); + var helloWorld = new TextMessage(Role.User, "echo Hello world"); var option = new GenerateReplyOptions { Temperature = 0, }; - var replyStream = await agent.GenerateStreamingReplyAsync(messages: new Message[] { message, helloWorld }, option); + var replyStream = await agent.GenerateStreamingReplyAsync(messages: new[] { message, helloWorld }, option); var answer = "[ECHO] Hello world"; - IMessage? finalReply = default; + IStreamingMessage? 
finalReply = default; await foreach (var reply in replyStream) { reply.From.Should().Be(agent.Name); finalReply = reply; } - if (finalReply is ToolCallResultMessage toolCallResultMessage) + if (finalReply is AggregateMessage aggregateMessage) { + var toolCallResultMessage = aggregateMessage.Message2; toolCallResultMessage.ToolCalls.First().Result.Should().Be(answer); toolCallResultMessage.From.Should().Be(agent.Name); toolCallResultMessage.ToolCalls.First().FunctionName.Should().Be(nameof(EchoAsync)); @@ -275,24 +276,29 @@ private async Task UpperCaseTest(IAgent agent) private async Task UpperCaseStreamingTestAsync(IStreamingAgent agent) { - var message = new Message(Role.System, "You are a helpful AI assistant that convert user message to upper case"); - var helloWorld = new Message(Role.User, "a b c d e f g h i j k l m n"); + var message = new TextMessage(Role.System, "You are a helpful AI assistant that convert user message to upper case"); + var helloWorld = new TextMessage(Role.User, "a b c d e f g h i j k l m n"); var option = new GenerateReplyOptions { Temperature = 0, }; - var replyStream = await agent.GenerateStreamingReplyAsync(messages: new Message[] { message, helloWorld }, option); + var replyStream = await agent.GenerateStreamingReplyAsync(messages: new[] { message, helloWorld }, option); var answer = "A B C D E F G H I J K L M N"; TextMessage? 
finalReply = default; await foreach (var reply in replyStream) { - if (reply is TextMessage textMessage) + if (reply is TextMessageUpdate update) { - textMessage.From.Should().Be(agent.Name); - - // the content should be part of the answer - textMessage.Content.Should().Be(answer.Substring(0, textMessage.Content!.Length)); - finalReply = textMessage; + update.From.Should().Be(agent.Name); + + if (finalReply is null) + { + finalReply = new TextMessage(update); + } + else + { + finalReply.Update(update); + } continue; } From 40265e68e879b857910da17bfc1be3661fd383ec Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Fri, 1 Mar 2024 15:35:43 -0800 Subject: [PATCH 19/27] update sample --- .../AutoGen.BasicSamples/Example02_TwoAgent_MathChat.cs | 2 +- dotnet/sample/AutoGen.BasicSamples/Program.cs | 2 +- dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs | 6 ++---- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/dotnet/sample/AutoGen.BasicSamples/Example02_TwoAgent_MathChat.cs b/dotnet/sample/AutoGen.BasicSamples/Example02_TwoAgent_MathChat.cs index 210362d54f5..3f543fc8bde 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Example02_TwoAgent_MathChat.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Example02_TwoAgent_MathChat.cs @@ -27,7 +27,7 @@ public static async Task RunAsync() }) .RegisterPostProcess(async (_, reply, _) => { - if (reply.GetContent()?.Contains("TERMINATE") is true) + if (reply.GetContent()?.ToLower().Contains("terminate") is true) { return new TextMessage(Role.Assistant, GroupChatExtension.TERMINATE, from: reply.From); } diff --git a/dotnet/sample/AutoGen.BasicSamples/Program.cs b/dotnet/sample/AutoGen.BasicSamples/Program.cs index 058cd5fa044..fb0bacbb5a1 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Program.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Program.cs @@ -1,4 +1,4 @@ // Copyright (c) Microsoft Corporation. All rights reserved. 
// Program.cs -await Example05_Dalle_And_GPT4V.RunAsync(); +await Example02_TwoAgent_MathChat.RunAsync(); diff --git a/dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs b/dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs index 5470e9e13ae..60d2b2638b2 100644 --- a/dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs +++ b/dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs @@ -37,7 +37,7 @@ public MiddlewareStreamingAgent(IStreamingAgent agent, string? name = null, IEnu public async Task GenerateReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) { - throw new NotImplementedException("Streaming agent does not support non-streaming reply."); + return await _agent.GenerateReplyAsync(messages, options, cancellationToken); } public Task> GenerateStreamingReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) @@ -76,13 +76,11 @@ public DelegateStreamingAgent(IStreamingMiddleware middleware, IStreamingAgent n public async Task GenerateReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) { - throw new NotImplementedException("Streaming agent does not support non-streaming reply."); + return await innerAgent.GenerateReplyAsync(messages, options, cancellationToken); } public Task> GenerateStreamingReplyAsync(IEnumerable messages, GenerateReplyOptions? 
options = null, CancellationToken cancellationToken = default) { - // TODO - // fix this var context = new MiddlewareContext(messages, options); return middleware.InvokeAsync(context, innerAgent, cancellationToken); } From a392b232d6ba955c0571a2a548d5c0938acca7fb Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Fri, 1 Mar 2024 16:39:03 -0800 Subject: [PATCH 20/27] update samples --- .../Example01_AssistantAgent.cs | 14 ++-- .../Example02_TwoAgent_MathChat.cs | 11 ++- .../Example03_Agent_FunctionCall.cs | 35 ++++++---- ...Example04_Dynamic_GroupChat_Coding_Task.cs | 55 ++++++--------- .../Example06_UserProxyAgent.cs | 15 ++-- ...7_Dynamic_GroupChat_Calculate_Fibonacci.cs | 43 ++++-------- .../AutoGen.BasicSamples/LLMConfiguration.cs | 40 +++++++++++ dotnet/sample/AutoGen.BasicSamples/Program.cs | 4 +- .../Middleware/PrintMessageMiddleware.cs | 68 +++++++++++++++++-- 9 files changed, 174 insertions(+), 111 deletions(-) create mode 100644 dotnet/sample/AutoGen.BasicSamples/LLMConfiguration.cs diff --git a/dotnet/sample/AutoGen.BasicSamples/Example01_AssistantAgent.cs b/dotnet/sample/AutoGen.BasicSamples/Example01_AssistantAgent.cs index 5a7c9612cea..54731263351 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Example01_AssistantAgent.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Example01_AssistantAgent.cs @@ -2,8 +2,8 @@ // Example01_AssistantAgent.cs using AutoGen; +using AutoGen.BasicSample; using FluentAssertions; -using autogen = AutoGen.LLMConfigAPI; /// /// This example shows the basic usage of class. @@ -12,13 +12,11 @@ public static class Example01_AssistantAgent { public static async Task RunAsync() { - // get OpenAI Key and create config - var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? 
throw new Exception("Please set OPENAI_API_KEY environment variable."); - var llmConfig = autogen.GetOpenAIConfigList(openAIKey, new[] { "gpt-3.5-turbo" }); + var gpt35 = LLMConfiguration.GetAzureOpenAIGPT3_5_Turbo(); var config = new ConversableAgentConfig { Temperature = 0, - ConfigList = llmConfig, + ConfigList = [gpt35], }; // create assistant agent @@ -31,8 +29,7 @@ public static async Task RunAsync() // talk to the assistant agent var reply = await assistantAgent.SendAsync("hello world"); reply.Should().BeOfType(); - var textReply = (TextMessage)reply; - textReply.Content.Should().Be("HELLO WORLD"); + reply.GetContent().Should().Be("HELLO WORLD"); // to carry on the conversation, pass the previous conversation history to the next call var conversationHistory = new List @@ -43,7 +40,6 @@ public static async Task RunAsync() reply = await assistantAgent.SendAsync("hello world again", conversationHistory); reply.Should().BeOfType(); - textReply = (TextMessage)reply; - textReply.Content?.Should().Be("HELLO WORLD AGAIN"); + reply.GetContent().Should().Be("HELLO WORLD AGAIN"); } } diff --git a/dotnet/sample/AutoGen.BasicSamples/Example02_TwoAgent_MathChat.cs b/dotnet/sample/AutoGen.BasicSamples/Example02_TwoAgent_MathChat.cs index 3f543fc8bde..3466455e536 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Example02_TwoAgent_MathChat.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Example02_TwoAgent_MathChat.cs @@ -2,16 +2,15 @@ // Example02_TwoAgent_MathChat.cs using AutoGen; +using AutoGen.BasicSample; using FluentAssertions; -using autogen = AutoGen.LLMConfigAPI; public static class Example02_TwoAgent_MathChat { public static async Task RunAsync() { #region code_snippet_1 - // get OpenAI Key and create config - var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? 
throw new Exception("Please set OPENAI_API_KEY environment variable."); - var llmConfig = autogen.GetOpenAIConfigList(openAIKey, new[] { "gpt-3.5-turbo" }); + // get gpt-3.5-turbo config + var gpt35 = LLMConfiguration.GetAzureOpenAIGPT3_5_Turbo(); // create teacher agent // teacher agent will create math questions @@ -23,7 +22,7 @@ public static async Task RunAsync() llmConfig: new ConversableAgentConfig { Temperature = 0, - ConfigList = llmConfig, + ConfigList = [gpt35], }) .RegisterPostProcess(async (_, reply, _) => { @@ -44,7 +43,7 @@ public static async Task RunAsync() llmConfig: new ConversableAgentConfig { Temperature = 0, - ConfigList = llmConfig, + ConfigList = [gpt35], }) .RegisterPrintFormatMessageHook(); diff --git a/dotnet/sample/AutoGen.BasicSamples/Example03_Agent_FunctionCall.cs b/dotnet/sample/AutoGen.BasicSamples/Example03_Agent_FunctionCall.cs index f26bc0fa139..af55d63735c 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Example03_Agent_FunctionCall.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Example03_Agent_FunctionCall.cs @@ -2,8 +2,8 @@ // Example03_Agent_FunctionCall.cs using AutoGen; +using AutoGen.BasicSample; using FluentAssertions; -using autogen = AutoGen.LLMConfigAPI; /// /// This example shows how to add type-safe function call to an agent. @@ -41,11 +41,10 @@ public async Task CalculateTax(int price, float taxRate) return $"tax is {price * taxRate}"; } - public async Task RunAsync() + public static async Task RunAsync() { - // get OpenAI Key and create config - var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? 
throw new Exception("Please set OPENAI_API_KEY environment variable."); - var llmConfig = autogen.GetOpenAIConfigList(openAIKey, new[] { "gpt-3.5-turbo" }); // the version of GPT needs to support function call, a.k.a later than 0613 + var instance = new Example03_Agent_FunctionCall(); + var gpt35 = LLMConfiguration.GetOpenAIGPT3_5_Turbo(); // AutoGen makes use of AutoGen.SourceGenerator to automatically generate FunctionDefinition and FunctionCallWrapper for you. // The FunctionDefinition will be created based on function signature and XML documentation. @@ -53,12 +52,12 @@ public async Task RunAsync() var config = new ConversableAgentConfig { Temperature = 0, - ConfigList = llmConfig, + ConfigList = [gpt35], FunctionContracts = new[] { - ConcatStringFunctionContract, - UpperCaseFunctionContract, - CalculateTaxFunctionContract, + instance.ConcatStringFunctionContract, + instance.UpperCaseFunctionContract, + instance.CalculateTaxFunctionContract, }, }; @@ -68,23 +67,29 @@ public async Task RunAsync() llmConfig: config, functionMap: new Dictionary>> { - { nameof(ConcatString), this.ConcatStringWrapper }, - { nameof(UpperCase), this.UpperCaseWrapper }, - { nameof(CalculateTax), this.CalculateTaxWrapper }, + { nameof(ConcatString), instance.ConcatStringWrapper }, + { nameof(UpperCase), instance.UpperCaseWrapper }, + { nameof(CalculateTax), instance.CalculateTaxWrapper }, }) .RegisterPrintFormatMessageHook(); // talk to the assistant agent var upperCase = await agent.SendAsync("convert to upper case: hello world"); - upperCase.Should().BeOfType>(); upperCase.GetContent()?.Should().Be("HELLO WORLD"); + upperCase.Should().BeOfType>(); + upperCase.GetToolCalls().Should().HaveCount(1); + upperCase.GetToolCalls().First().FunctionName.Should().Be(nameof(UpperCase)); var concatString = await agent.SendAsync("concatenate strings: a, b, c, d, e"); - concatString.Should().BeOfType>(); concatString.GetContent()?.Should().Be("a b c d e"); + concatString.Should().BeOfType>(); + 
concatString.GetToolCalls().Should().HaveCount(1); + concatString.GetToolCalls().First().FunctionName.Should().Be(nameof(ConcatString)); var calculateTax = await agent.SendAsync("calculate tax: 100, 0.1"); - calculateTax.Should().BeOfType>(); calculateTax.GetContent().Should().Be("tax is 10"); + calculateTax.Should().BeOfType>(); + calculateTax.GetToolCalls().Should().HaveCount(1); + calculateTax.GetToolCalls().First().FunctionName.Should().Be(nameof(CalculateTax)); } } diff --git a/dotnet/sample/AutoGen.BasicSamples/Example04_Dynamic_GroupChat_Coding_Task.cs b/dotnet/sample/AutoGen.BasicSamples/Example04_Dynamic_GroupChat_Coding_Task.cs index 7539c8bcf3d..72f98555095 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Example04_Dynamic_GroupChat_Coding_Task.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Example04_Dynamic_GroupChat_Coding_Task.cs @@ -2,9 +2,10 @@ // Example04_Dynamic_GroupChat_Coding_Task.cs using AutoGen; +using AutoGen.BasicSample; using AutoGen.DotnetInteractive; +using AutoGen.OpenAI; using FluentAssertions; -using autogen = AutoGen.LLMConfigAPI; public partial class Example04_Dynamic_GroupChat_Coding_Task { @@ -17,7 +18,7 @@ public static async Task RunAsync() if (!Directory.Exists(workDir)) Directory.CreateDirectory(workDir); - var service = new InteractiveService(workDir); + using var service = new InteractiveService(workDir); var dotnetInteractiveFunctions = new DotnetInteractiveFunction(service); var result = Path.Combine(workDir, "result.txt"); @@ -26,26 +27,19 @@ public static async Task RunAsync() await service.StartAsync(workDir, default); - // get OpenAI Key and create config - var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? 
throw new Exception("Please set OPENAI_API_KEY environment variable."); - var gptConfig = autogen.GetOpenAIConfigList(openAIKey, ["gpt-4"]); + var gptConfig = LLMConfiguration.GetAzureOpenAIGPT3_5_Turbo(); - var helperAgent = new AssistantAgent( + var helperAgent = new GPTAgent( name: "helper", systemMessage: "You are a helpful AI assistant", - llmConfig: new ConversableAgentConfig - { - Temperature = 0, - ConfigList = gptConfig, - }); + temperature: 0f, + config: gptConfig); - var groupAdmin = new AssistantAgent( + var groupAdmin = new GPTAgent( name: "groupAdmin", - llmConfig: new ConversableAgentConfig - { - Temperature = 0, - ConfigList = gptConfig, - }); + systemMessage: "You are the admin of the group chat", + temperature: 0f, + config: gptConfig); var userProxy = new UserProxyAgent(name: "user", defaultReply: GroupChatExtension.TERMINATE, humanInputMode: HumanInputMode.NEVER) .RegisterPrintFormatMessageHook(); @@ -99,7 +93,7 @@ public static async Task RunAsync() llmConfig: new ConversableAgentConfig { Temperature = 0, - ConfigList = gptConfig, + ConfigList = [gptConfig], }) .RegisterPrintFormatMessageHook(); @@ -108,7 +102,7 @@ public static async Task RunAsync() // The dotnet coder write dotnet code to resolve the task. // The code reviewer review the code block from coder's reply. // The nuget agent install nuget packages if there's any. - var coderAgent = new AssistantAgent( + var coderAgent = new GPTAgent( name: "coder", systemMessage: @"You act as dotnet coder, you write dotnet code to resolve task. Once you finish writing code, ask runner to run the code for you. @@ -130,11 +124,8 @@ public static async Task RunAsync() Here's some externel information - The link to mlnet repo is: https://github.com/dotnet/machinelearning. you don't need a token to use github pr api. Make sure to include a User-Agent header, otherwise github will reject it. 
", - llmConfig: new ConversableAgentConfig - { - Temperature = 0.4f, - ConfigList = gptConfig, - }) + config: gptConfig, + temperature: 0.4f) .RegisterPrintFormatMessageHook(); // code reviewer agent will review if code block from coder's reply satisfy the following conditions: @@ -142,7 +133,7 @@ public static async Task RunAsync() // - The code block is csharp code block // - The code block is top level statement // - The code block is not using declaration - var codeReviewAgent = new AssistantAgent( + var codeReviewAgent = new GPTAgent( name: "reviewer", systemMessage: """ You are a code reviewer who reviews code from coder. You need to check if the code satisfy the following conditions: @@ -168,11 +159,8 @@ public static async Task RunAsync() ``` """, - llmConfig: new ConversableAgentConfig - { - Temperature = 0, - ConfigList = gptConfig, - }) + config: gptConfig, + temperature: 0f) .RegisterPrintFormatMessageHook(); // create runner agent @@ -181,12 +169,7 @@ public static async Task RunAsync() // It also truncate the output if the output is too long. var runner = new AssistantAgent( name: "runner", - defaultReply: "No code available, coder, write code please", - llmConfig: new ConversableAgentConfig - { - Temperature = 0, - ConfigList = gptConfig, - }) + defaultReply: "No code available, coder, write code please") .RegisterDotnetCodeBlockExectionHook(interactiveService: service) .RegisterMiddleware(async (msgs, option, agent, ct) => { diff --git a/dotnet/sample/AutoGen.BasicSamples/Example06_UserProxyAgent.cs b/dotnet/sample/AutoGen.BasicSamples/Example06_UserProxyAgent.cs index a98fbcc5c8e..9d87489874b 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Example06_UserProxyAgent.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Example06_UserProxyAgent.cs @@ -1,6 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. 
// Example06_UserProxyAgent.cs -using autogen = AutoGen.LLMConfigAPI; +using AutoGen.OpenAI; namespace AutoGen.BasicSample; @@ -8,19 +8,12 @@ public static class Example06_UserProxyAgent { public static async Task RunAsync() { - // get OpenAI Key and create config - var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable."); - var llmConfig = autogen.GetOpenAIConfigList(openAIKey, new[] { "gpt-3.5-turbo" }); - var config = new ConversableAgentConfig - { - Temperature = 0, - ConfigList = llmConfig, - }; + var gpt35 = LLMConfiguration.GetOpenAIGPT3_5_Turbo(); - var assistantAgent = new AssistantAgent( + var assistantAgent = new GPTAgent( name: "assistant", systemMessage: "You are an assistant that help user to do some tasks.", - llmConfig: config) + config: gpt35) .RegisterPrintFormatMessageHook(); // set human input mode to ALWAYS so that user always provide input diff --git a/dotnet/sample/AutoGen.BasicSamples/Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs b/dotnet/sample/AutoGen.BasicSamples/Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs index c46cd20584e..14e9aa8d6b3 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs @@ -6,6 +6,8 @@ using AutoGen; using System.Text; using FluentAssertions; +using AutoGen.BasicSample; +using AutoGen.OpenAI; public partial class Example07_Dynamic_GroupChat_Calculate_Fibonacci { @@ -44,7 +46,6 @@ public struct CodeReviewResult } #endregion reviewer_function - public static async Task RunAsync() { var functions = new Example07_Dynamic_GroupChat_Calculate_Fibonacci(); @@ -53,28 +54,19 @@ public static async Task RunAsync() if (!Directory.Exists(workDir)) Directory.CreateDirectory(workDir); - var service = new InteractiveService(workDir); + using var service = new InteractiveService(workDir); var 
dotnetInteractiveFunctions = new DotnetInteractiveFunction(service); await service.StartAsync(workDir, default); - // get OpenAI Key and create config - var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable."); - var gpt3Config = LLMConfigAPI.GetOpenAIConfigList(openAIKey, new[] { "gpt-3.5-turbo" }); + var gpt3Config = LLMConfiguration.GetAzureOpenAIGPT3_5_Turbo(); #region create_reviewer - var reviewer = new AssistantAgent( + var reviewer = new GPTAgent( name: "code_reviewer", systemMessage: @"You review code block from coder", - llmConfig: new ConversableAgentConfig - { - Temperature = 0, - ConfigList = gpt3Config, - FunctionContracts = new[] - { - functions.ReviewCodeBlockFunctionContract, - }, - }, + config: gpt3Config, + functions: [functions.ReviewCodeBlockFunction], functionMap: new Dictionary>>() { { nameof(ReviewCodeBlock), functions.ReviewCodeBlockWrapper }, @@ -153,7 +145,7 @@ public static async Task RunAsync() #endregion create_reviewer #region create_coder - var coder = new AssistantAgent( + var coder = new GPTAgent( name: "coder", systemMessage: @"You act as dotnet coder, you write dotnet code to resolve task. Once you finish writing code, ask runner to run the code for you. @@ -171,11 +163,8 @@ public static async Task RunAsync() ``` If your code is incorrect, runner will tell you the error message. 
Fix the error and send the code again.", - llmConfig: new ConversableAgentConfig - { - Temperature = 0.4f, - ConfigList = gpt3Config, - }) + config: gpt3Config, + temperature: 0.4f) .RegisterPrintFormatMessageHook(); #endregion create_coder @@ -211,14 +200,11 @@ public static async Task RunAsync() #endregion create_runner #region create_admin - var admin = new AssistantAgent( + var admin = new GPTAgent( name: "admin", systemMessage: "You are group admin, terminate the group chat once task is completed by saying [TERMINATE] plus the final answer", - llmConfig: new ConversableAgentConfig - { - Temperature = 0, - ConfigList = gpt3Config, - }) + temperature: 0, + config: gpt3Config) .RegisterPostProcess(async (_, reply, _) => { if (reply is TextMessage textMessage && textMessage.Content.Contains("TERMINATE") is true) @@ -260,8 +246,7 @@ public static async Task RunAsync() lastMessage.From.Should().Be("admin"); lastMessage.IsGroupChatTerminateMessage().Should().BeTrue(); lastMessage.Should().BeOfType(); - var textMessage = (TextMessage)lastMessage; - textMessage.Content.Should().Contain(the39thFibonacciNumber.ToString()); + lastMessage.GetContent().Should().Contain(the39thFibonacciNumber.ToString()); #endregion start_group_chat } } diff --git a/dotnet/sample/AutoGen.BasicSamples/LLMConfiguration.cs b/dotnet/sample/AutoGen.BasicSamples/LLMConfiguration.cs new file mode 100644 index 00000000000..37c9b0d7ade --- /dev/null +++ b/dotnet/sample/AutoGen.BasicSamples/LLMConfiguration.cs @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// LLMConfiguration.cs + +using AutoGen.OpenAI; + +namespace AutoGen.BasicSample; + +internal static class LLMConfiguration +{ + public static OpenAIConfig GetOpenAIGPT3_5_Turbo() + { + var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? 
throw new Exception("Please set OPENAI_API_KEY environment variable."); + var modelId = "gpt-3.5-turbo"; + return new OpenAIConfig(openAIKey, modelId); + } + + public static OpenAIConfig GetOpenAIGPT4() + { + var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable."); + var modelId = "gpt-4"; + + return new OpenAIConfig(openAIKey, modelId); + } + + public static AzureOpenAIConfig GetAzureOpenAIGPT3_5_Turbo(string deployName = "gpt-35-turbo-16k") + { + var azureOpenAIKey = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable."); + var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable."); + + return new AzureOpenAIConfig(endpoint, deployName, azureOpenAIKey); + } + + public static AzureOpenAIConfig GetAzureOpenAIGPT4(string deployName = "gpt-4") + { + var azureOpenAIKey = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable."); + var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable."); + + return new AzureOpenAIConfig(endpoint, deployName, azureOpenAIKey); + } +} diff --git a/dotnet/sample/AutoGen.BasicSamples/Program.cs b/dotnet/sample/AutoGen.BasicSamples/Program.cs index fb0bacbb5a1..665655591ee 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Program.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Program.cs @@ -1,4 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. 
// Program.cs -await Example02_TwoAgent_MathChat.RunAsync(); +using AutoGen.BasicSample; + +await Example10_SemanticKernel.RunAsync(); diff --git a/dotnet/src/AutoGen.Core/Middleware/PrintMessageMiddleware.cs b/dotnet/src/AutoGen.Core/Middleware/PrintMessageMiddleware.cs index 9738a3a8e4a..9461b697357 100644 --- a/dotnet/src/AutoGen.Core/Middleware/PrintMessageMiddleware.cs +++ b/dotnet/src/AutoGen.Core/Middleware/PrintMessageMiddleware.cs @@ -16,12 +16,72 @@ public class PrintMessageMiddleware : IMiddleware public async Task InvokeAsync(MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default) { - var reply = await agent.GenerateReplyAsync(context.Messages, context.Options, cancellationToken); + if (agent is IStreamingAgent streamingAgent) + { + IMessage? recentUpdate = null; + await foreach (var message in await streamingAgent.GenerateStreamingReplyAsync(context.Messages, context.Options, cancellationToken)) + { + if (message is TextMessageUpdate textMessageUpdate) + { + if (recentUpdate is null) + { + // Print from: xxx + Console.WriteLine($"from: {textMessageUpdate.From}"); + recentUpdate = new TextMessage(textMessageUpdate); + Console.Write(textMessageUpdate.Content); + } + else if (recentUpdate is TextMessage recentTextMessage) + { + // Print the content of the message + Console.Write(textMessageUpdate.Content); + recentTextMessage.Update(textMessageUpdate); + } + else + { + throw new InvalidOperationException("The recent update is not a TextMessage"); + } + } + else if (message is ToolCallMessageUpdate toolCallUpdate) + { + if (recentUpdate is null) + { + recentUpdate = new ToolCallMessage(toolCallUpdate); + } + else if (recentUpdate is ToolCallMessage recentToolCallMessage) + { + recentToolCallMessage.Update(toolCallUpdate); + } + else + { + throw new InvalidOperationException("The recent update is not a ToolCallMessage"); + } + } + else if (message is IMessage imessage) + { + recentUpdate = imessage; + } + else + { + throw new 
InvalidOperationException("The message is not a valid message"); + } + } + Console.WriteLine(); + if (recentUpdate is not null && recentUpdate is not TextMessage) + { + Console.WriteLine(recentUpdate.FormatMessage()); + } - var formattedMessages = reply.FormatMessage(); + return recentUpdate ?? throw new InvalidOperationException("The message is not a valid message"); + } + else + { + var reply = await agent.GenerateReplyAsync(context.Messages, context.Options, cancellationToken); - Console.WriteLine(formattedMessages); + var formattedMessages = reply.FormatMessage(); - return reply; + Console.WriteLine(formattedMessages); + + return reply; + } } } From 6a268a7472fb50f12b14eb433fd086dcbaecc298 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Sat, 2 Mar 2024 11:53:53 -0800 Subject: [PATCH 21/27] update --- dotnet/eng/MetaInfo.props | 2 +- dotnet/eng/Version.props | 4 +- .../Example10_SemanticKernel.cs | 35 ++- .../AutoGen.Core/Message/MessageEnvelope.cs | 23 +- .../AutoGen.Core/Message/MultiModalMessage.cs | 3 +- .../src/AutoGen.Core/Message/TextMessage.cs | 8 +- dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs | 4 +- ...ctor.cs => ChatRequestMessageConnector.cs} | 8 +- .../Middleware/ChatMessageContentConnector.cs | 251 ++++++++++++++++++ .../SemanticKernelAgent.cs | 239 ++++------------- .../test/AutoGen.Tests/OpenAIMessageTests.cs | 8 +- dotnet/test/AutoGen.Tests/SingleAgentTest.cs | 2 +- 12 files changed, 363 insertions(+), 224 deletions(-) rename dotnet/src/AutoGen.OpenAI/Middleware/{OpenAIMessageConnector.cs => ChatRequestMessageConnector.cs} (97%) create mode 100644 dotnet/src/AutoGen.SemanticKernel/Middleware/ChatMessageContentConnector.cs diff --git a/dotnet/eng/MetaInfo.props b/dotnet/eng/MetaInfo.props index 5c149749fd6..b40b9c69723 100644 --- a/dotnet/eng/MetaInfo.props +++ b/dotnet/eng/MetaInfo.props @@ -3,7 +3,7 @@ 0.0.8 AutoGen - https://github.com/microsoft/autogen + https://microsoft.github.io/autogen-for-net/ https://github.com/microsoft/autogen git 
MIT diff --git a/dotnet/eng/Version.props b/dotnet/eng/Version.props index 40b132a3513..0e88d76d953 100644 --- a/dotnet/eng/Version.props +++ b/dotnet/eng/Version.props @@ -2,8 +2,8 @@ 1.0.0-beta.13 - 1.4.0 - 1.4.0-alpha + 1.5.0 + 1.5.0-alpha 5.0.0 4.3.0 6.0.0 diff --git a/dotnet/sample/AutoGen.BasicSamples/Example10_SemanticKernel.cs b/dotnet/sample/AutoGen.BasicSamples/Example10_SemanticKernel.cs index 18fd7a905ac..eea6dc62298 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Example10_SemanticKernel.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Example10_SemanticKernel.cs @@ -3,7 +3,10 @@ using System.ComponentModel; using AutoGen.SemanticKernel.Extension; +using AutoGen.SemanticKernel.Middleware; +using FluentAssertions; using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; using Microsoft.SemanticKernel.Connectors.OpenAI; namespace AutoGen.BasicSample; @@ -46,16 +49,34 @@ public static async Task RunAsync() }; kernel.Plugins.AddFromObject(new LightPlugin()); - var assistantAgent = kernel - .ToSemanticKernelAgent(name: "assistant", systemMessage: "You control the light", settings) + var skAgent = kernel + .ToSemanticKernelAgent(name: "assistant", systemMessage: "You control the light", settings); + + // Send a message to the skAgent, the skAgent supports the following message types: + // - IMessage + // - (streaming) IMessage + // You can create an IMessage using MessageEnvelope.Create + var chatMessageContent = MessageEnvelope.Create(new ChatMessageContent(AuthorRole.User, "Toggle the light")); + var reply = await skAgent.SendAsync(chatMessageContent); + reply.Should().BeOfType>(); + Console.WriteLine((reply as IMessage).Content.Items[0].As().Text); + + // To support more AutoGen built-in IMessage, register skAgent with ChatMessageContentConnector + var connector = new ChatMessageContentConnector(); + var skAgentWithMiddlewares = skAgent + .RegisterMiddleware(connector) .RegisterPrintFormatMessageHook(); + // Now the
skAgentWithMiddlewares supports more IMessage types like TextMessage, ImageMessage or MultiModalMessage + // It also registers a print format message hook to print the message in a human readable format to the console + await skAgent.SendAsync(chatMessageContent); + await skAgentWithMiddlewares.SendAsync(new TextMessage(Role.User, "Toggle the light")); + + // The more message types an agent supports, the more flexible it is to be used in different scenarios + // For example, since the TextMessage is supported, the skAgentWithMiddlewares can be used with user proxy. + var userProxy = new UserProxyAgent("user"); - var userProxyAgent = new UserProxyAgent(name: "user", humanInputMode: HumanInputMode.ALWAYS); - await userProxyAgent.InitiateChatAsync( - receiver: assistantAgent, - message: "Hey assistant, please help me to do some tasks.", - maxRound: 10); + await skAgentWithMiddlewares.InitiateChatAsync(userProxy, "how can I help you today"); } } diff --git a/dotnet/src/AutoGen.Core/Message/MessageEnvelope.cs b/dotnet/src/AutoGen.Core/Message/MessageEnvelope.cs index 2646174b1ff..f83bea27926 100644 --- a/dotnet/src/AutoGen.Core/Message/MessageEnvelope.cs +++ b/dotnet/src/AutoGen.Core/Message/MessageEnvelope.cs @@ -5,18 +5,33 @@ namespace AutoGen.Core; -public class MessageEnvelope : IMessage, IStreamingMessage +public abstract class MessageEnvelope : IMessage, IStreamingMessage { - public MessageEnvelope(T content, string? from = null, IDictionary? metadata = null) + public MessageEnvelope(string? from = null, IDictionary? metadata = null) { - this.Content = content; this.From = from; this.Metadata = metadata ?? new Dictionary(); } - public T Content { get; } + public static MessageEnvelope Create(TContent content, string? from = null, IDictionary? metadata = null) + { + return new MessageEnvelope(content, from, metadata); + } public string?
From { get; set; } public IDictionary Metadata { get; set; } } + +public class MessageEnvelope : MessageEnvelope, IMessage, IStreamingMessage +{ + public MessageEnvelope(T content, string? from = null, IDictionary? metadata = null) + : base(from, metadata) + { + this.Content = content; + this.From = from; + this.Metadata = metadata ?? new Dictionary(); + } + + public T Content { get; } +} diff --git a/dotnet/src/AutoGen.Core/Message/MultiModalMessage.cs b/dotnet/src/AutoGen.Core/Message/MultiModalMessage.cs index 3fe1d34383b..9dd2a37af0b 100644 --- a/dotnet/src/AutoGen.Core/Message/MultiModalMessage.cs +++ b/dotnet/src/AutoGen.Core/Message/MultiModalMessage.cs @@ -8,8 +8,9 @@ namespace AutoGen.Core; public class MultiModalMessage : IMessage { - public MultiModalMessage(IEnumerable content, string? from = null) + public MultiModalMessage(Role role, IEnumerable content, string? from = null) { + this.Role = role; this.Content = content; this.From = from; this.Validate(); diff --git a/dotnet/src/AutoGen.Core/Message/TextMessage.cs b/dotnet/src/AutoGen.Core/Message/TextMessage.cs index b59ddfb9a57..ed4d7436dde 100644 --- a/dotnet/src/AutoGen.Core/Message/TextMessage.cs +++ b/dotnet/src/AutoGen.Core/Message/TextMessage.cs @@ -14,7 +14,7 @@ public TextMessage(Role role, string content, string? from = null) public TextMessage(TextMessageUpdate update) { - this.Content = update.Content; + this.Content = update.Content ?? string.Empty; this.Role = update.Role; this.From = update.From; } @@ -31,7 +31,7 @@ public void Update(TextMessageUpdate update) throw new System.ArgumentException("From mismatch", nameof(update)); } - this.Content = this.Content + update.Content; + this.Content = this.Content + update.Content ?? string.Empty; } public Role Role { get; set; } @@ -48,14 +48,14 @@ public override string ToString() public class TextMessageUpdate : IStreamingMessage { - public TextMessageUpdate(Role role, string content, string? 
from = null) + public TextMessageUpdate(Role role, string? content, string? from = null) { this.Content = content; this.From = from; this.Role = role; } - public string Content { get; set; } + public string? Content { get; set; } public string? From { get; set; } diff --git a/dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs b/dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs index dd2383b148c..0011c74f8da 100644 --- a/dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs +++ b/dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs @@ -86,7 +86,7 @@ public class GPTAgent : IStreamingAgent GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) { - var oaiConnectorMiddleware = new OpenAIMessageConnector(); + var oaiConnectorMiddleware = new ChatRequestMessageConnector(); var agent = this._innerAgent.RegisterMiddleware(oaiConnectorMiddleware); if (this.functionMap is not null) { @@ -102,7 +102,7 @@ public class GPTAgent : IStreamingAgent GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) { - var oaiConnectorMiddleware = new OpenAIMessageConnector(); + var oaiConnectorMiddleware = new ChatRequestMessageConnector(); var agent = this._innerAgent.RegisterStreamingMiddleware(oaiConnectorMiddleware); if (this.functionMap is not null) { diff --git a/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIMessageConnector.cs b/dotnet/src/AutoGen.OpenAI/Middleware/ChatRequestMessageConnector.cs similarity index 97% rename from dotnet/src/AutoGen.OpenAI/Middleware/OpenAIMessageConnector.cs rename to dotnet/src/AutoGen.OpenAI/Middleware/ChatRequestMessageConnector.cs index ad3ac26c638..816bcd8b91f 100644 --- a/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIMessageConnector.cs +++ b/dotnet/src/AutoGen.OpenAI/Middleware/ChatRequestMessageConnector.cs @@ -12,7 +12,7 @@ namespace AutoGen.OpenAI.Middleware; /// -/// This middleware converts the incoming to before sending to agent. And converts the output to after receiving from agent. 
+/// This middleware converts the incoming to where T is before sending to agent. And converts the output to after receiving from agent. /// Supported are /// - /// - @@ -23,16 +23,16 @@ namespace AutoGen.OpenAI.Middleware; /// - where T is /// - where TMessage1 is and TMessage2 is /// -public class OpenAIMessageConnector : IMiddleware, IStreamingMiddleware +public class ChatRequestMessageConnector : IMiddleware, IStreamingMiddleware { private bool strictMode = false; - public OpenAIMessageConnector(bool strictMode = false) + public ChatRequestMessageConnector(bool strictMode = false) { this.strictMode = strictMode; } - public string? Name => nameof(OpenAIMessageConnector); + public string? Name => nameof(ChatRequestMessageConnector); public async Task InvokeAsync(MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default) { diff --git a/dotnet/src/AutoGen.SemanticKernel/Middleware/ChatMessageContentConnector.cs b/dotnet/src/AutoGen.SemanticKernel/Middleware/ChatMessageContentConnector.cs new file mode 100644 index 00000000000..3d1931b444c --- /dev/null +++ b/dotnet/src/AutoGen.SemanticKernel/Middleware/ChatMessageContentConnector.cs @@ -0,0 +1,251 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ChatMessageContentConnector.cs + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; + +namespace AutoGen.SemanticKernel.Middleware; + +/// +/// This middleware converts the incoming to before passing to agent. +/// And converts the reply message from to before returning to the caller. 
+/// +/// requirement for agent +/// - Input message type: where T is +/// - Reply message type: where T is +/// - (streaming) Reply message type: where T is +/// +/// This middleware supports the following message types: +/// - +/// - +/// - +/// +/// This middleware returns the following message types: +/// - +/// - +/// - +/// - (streaming) +/// +public class ChatMessageContentConnector : IMiddleware, IStreamingMiddleware +{ + public string? Name => nameof(ChatMessageContentConnector); + + public async Task InvokeAsync(MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default) + { + var messages = context.Messages; + + var chatMessageContents = ProcessMessage(messages, agent) + .Select(m => new MessageEnvelope(m)); + var reply = await agent.GenerateReplyAsync(chatMessageContents, context.Options, cancellationToken); + + return PostProcessMessage(reply); + } + + public Task> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, CancellationToken cancellationToken = default) + { + return Task.FromResult(InvokeStreamingAsync(context, agent, cancellationToken)); + } + + private async IAsyncEnumerable InvokeStreamingAsync( + MiddlewareContext context, + IStreamingAgent agent, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + var chatMessageContents = ProcessMessage(context.Messages, agent) + .Select(m => new MessageEnvelope(m)); + + await foreach (var reply in await agent.GenerateStreamingReplyAsync(chatMessageContents, context.Options, cancellationToken)) + { + yield return PostProcessStreamingMessage(reply); + } + } + + private IMessage PostProcessMessage(IMessage input) + { + return input switch + { + IMessage messageEnvelope => PostProcessMessage(messageEnvelope), + _ => throw new System.NotImplementedException(), + }; + } + + private IStreamingMessage PostProcessStreamingMessage(IStreamingMessage input) + { + return input switch + { + IStreamingMessage streamingMessage => 
PostProcessMessage(streamingMessage), + IMessage msg => PostProcessMessage(msg), + _ => throw new System.NotImplementedException(), + }; + } + + private IMessage PostProcessMessage(IMessage messageEnvelope) + { + var chatMessageContent = messageEnvelope.Content; + var items = chatMessageContent.Items.Select(i => i switch + { + TextContent txt => new TextMessage(Role.Assistant, txt.Text!, messageEnvelope.From), + ImageContent img when img.Uri is Uri uri => new ImageMessage(Role.Assistant, uri.ToString(), from: messageEnvelope.From), + ImageContent img when img.Uri is null => throw new InvalidOperationException("ImageContent.Uri is null"), + _ => throw new InvalidOperationException("Unsupported content type"), + }); + + if (items.Count() == 1) + { + return items.First(); + } + else + { + return new MultiModalMessage(Role.Assistant, items, from: messageEnvelope.From); + } + } + + private IStreamingMessage PostProcessMessage(IStreamingMessage streamingMessage) + { + var chatMessageContent = streamingMessage.Content; + if (chatMessageContent.ChoiceIndex > 0) + { + throw new InvalidOperationException("Only one choice is supported in streaming response"); + } + return new TextMessageUpdate(Role.Assistant, chatMessageContent.Content, streamingMessage.From); + } + + private IEnumerable ProcessMessage(IEnumerable messages, IAgent agent) + { + return messages.SelectMany(m => + { + if (m is IMessage chatMessageContent) + { + return [chatMessageContent.Content]; + } + if (m.From == agent.Name) + { + return ProcessMessageForSelf(m); + } + else + { + return ProcessMessageForOthers(m); + } + }); + } + + private IEnumerable ProcessMessageForSelf(IMessage message) + { + return message switch + { + TextMessage textMessage => ProcessMessageForSelf(textMessage), + MultiModalMessage multiModalMessage => ProcessMessageForSelf(multiModalMessage), + Message m => ProcessMessageForSelf(m), + _ => throw new System.NotImplementedException(), + }; + } + + private IEnumerable 
ProcessMessageForOthers(IMessage message) + { + return message switch + { + TextMessage textMessage => ProcessMessageForOthers(textMessage), + MultiModalMessage multiModalMessage => ProcessMessageForOthers(multiModalMessage), + Message m => ProcessMessageForOthers(m), + _ => throw new System.NotImplementedException(), + }; + } + + private IEnumerable ProcessMessageForSelf(TextMessage message) + { + if (message.Role == Role.System) + { + return [new ChatMessageContent(AuthorRole.System, message.Content)]; + } + else + { + return [new ChatMessageContent(AuthorRole.Assistant, message.Content)]; + } + } + + + private IEnumerable ProcessMessageForOthers(TextMessage message) + { + if (message.Role == Role.System) + { + return [new ChatMessageContent(AuthorRole.System, message.Content)]; + } + else + { + return [new ChatMessageContent(AuthorRole.User, message.Content)]; + } + } + + private IEnumerable ProcessMessageForSelf(MultiModalMessage message) + { + throw new System.InvalidOperationException("MultiModalMessage is not supported in the semantic kernel if it's from self."); + } + + private IEnumerable ProcessMessageForOthers(MultiModalMessage message) + { + var collections = new ChatMessageContentItemCollection(); + foreach (var item in message.Content) + { + if (item is TextMessage textContent) + { + collections.Add(new TextContent(textContent.Content)); + } + else if (item is ImageMessage imageContent) + { + collections.Add(new ImageContent(new Uri(imageContent.Url))); + } + else + { + throw new InvalidOperationException($"Unsupported message type: {item.GetType().Name}"); + } + } + return [new ChatMessageContent(AuthorRole.User, collections)]; + } + + + private IEnumerable ProcessMessageForSelf(Message message) + { + if (message.Role == Role.System) + { + return [new ChatMessageContent(AuthorRole.System, message.Content)]; + } + else if (message.Content is string && message.FunctionName is null && message.FunctionArguments is null) + { + return [new 
ChatMessageContent(AuthorRole.Assistant, message.Content)]; + } + else if (message.Content is null && message.FunctionName is not null && message.FunctionArguments is not null) + { + throw new System.InvalidOperationException("Function call is not supported in the semantic kernel if it's from self."); + } + else + { + throw new System.InvalidOperationException("Unsupported message type"); + } + } + + private IEnumerable ProcessMessageForOthers(Message message) + { + if (message.Role == Role.System) + { + return [new ChatMessageContent(AuthorRole.System, message.Content)]; + } + else if (message.Content is string && message.FunctionName is null && message.FunctionArguments is null) + { + return [new ChatMessageContent(AuthorRole.User, message.Content)]; + } + else if (message.Content is null && message.FunctionName is not null && message.FunctionArguments is not null) + { + throw new System.InvalidOperationException("Function call is not supported in the semantic kernel if it's from others."); + } + else + { + throw new System.InvalidOperationException("Unsupported message type"); + } + } +} diff --git a/dotnet/src/AutoGen.SemanticKernel/SemanticKernelAgent.cs b/dotnet/src/AutoGen.SemanticKernel/SemanticKernelAgent.cs index 684a9b4aa5c..78603a603d2 100644 --- a/dotnet/src/AutoGen.SemanticKernel/SemanticKernelAgent.cs +++ b/dotnet/src/AutoGen.SemanticKernel/SemanticKernelAgent.cs @@ -6,7 +6,7 @@ using System.Linq; using System.Threading; using System.Threading.Tasks; -using Azure.AI.OpenAI; +using AutoGen.SemanticKernel.Middleware; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.ChatCompletion; using Microsoft.SemanticKernel.Connectors.OpenAI; @@ -14,7 +14,19 @@ namespace AutoGen.SemanticKernel; /// -/// The agent that intergrade with the semantic kernel. 
+/// Semantic Kernel Agent +/// Incoming message could be one of the following types: +/// +/// where T is +/// +/// +/// Return message could be one of the following types: +/// +/// where T is +/// (streaming) where T is +/// +/// +/// To support more AutoGen built-in , register with . /// public class SemanticKernelAgent : IStreamingAgent { @@ -33,69 +45,40 @@ public class SemanticKernelAgent : IStreamingAgent _systemMessage = systemMessage; _settings = settings; } + public string Name { get; } public async Task GenerateReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) { - var chatMessageContents = ProcessMessage(messages); - // if there's no system message in chatMessageContents, add one to the beginning - if (!chatMessageContents.Any(c => c.Role == AuthorRole.System)) - { - chatMessageContents = new[] { new ChatMessageContent(AuthorRole.System, _systemMessage) }.Concat(chatMessageContents); - } - - var chatHistory = new ChatHistory(chatMessageContents); - var option = _settings ?? new OpenAIPromptExecutionSettings - { - Temperature = options?.Temperature ?? 0.7f, - MaxTokens = options?.MaxToken ??
1024, - StopSequences = options?.StopSequence, - }; - + var chatHistory = BuildChatHistory(messages); + var option = BuildOption(options); var chatService = _kernel.GetRequiredService(); var reply = await chatService.GetChatMessageContentsAsync(chatHistory, option, _kernel, cancellationToken); - if (reply.Count() == 1) + if (reply.Count > 1) { - // might be a plain text return or a function call return - var msg = reply.First(); - if (msg is OpenAIChatMessageContent oaiContent) - { - if (oaiContent.Content is string content) - { - return new Message(Role.Assistant, content, this.Name); - } - else if (oaiContent.ToolCalls is { Count: 1 } && oaiContent.ToolCalls.First() is ChatCompletionsFunctionToolCall toolCall) - { - return new Message(Role.Assistant, content: null, this.Name) - { - FunctionName = toolCall.Name, - FunctionArguments = toolCall.Arguments, - }; - } - else - { - // parallel function call is not supported - throw new InvalidOperationException("Unsupported return type, only plain text and function call are supported."); - } - } - else - { - throw new InvalidOperationException("Unsupported return type"); - } - } - else - { - throw new InvalidOperationException("Unsupported return type, multiple messages are not supported."); + throw new InvalidOperationException("ResultsPerPrompt greater than 1 is not supported in this semantic kernel agent"); } + + return new MessageEnvelope(reply.First(), from: this.Name); } public async Task> GenerateStreamingReplyAsync( IEnumerable messages, GenerateReplyOptions? 
options = null, CancellationToken cancellationToken = default) + { + var chatHistory = BuildChatHistory(messages); + var option = BuildOption(options); + var chatService = _kernel.GetRequiredService(); + var response = chatService.GetStreamingChatMessageContentsAsync(chatHistory, option, _kernel, cancellationToken); + + return ProcessMessage(response); + } + + private ChatHistory BuildChatHistory(IEnumerable messages) { var chatMessageContents = ProcessMessage(messages); // if there's no system message in chatMessageContents, add one to the beginning @@ -104,172 +87,40 @@ public async Task GenerateReplyAsync(IEnumerable messages, G chatMessageContents = new[] { new ChatMessageContent(AuthorRole.System, _systemMessage) }.Concat(chatMessageContents); } - var chatHistory = new ChatHistory(chatMessageContents); - var option = _settings ?? new OpenAIPromptExecutionSettings + return new ChatHistory(chatMessageContents); + } + + private PromptExecutionSettings BuildOption(GenerateReplyOptions? options) + { + return _settings ?? new OpenAIPromptExecutionSettings { Temperature = options?.Temperature ?? 0.7f, MaxTokens = options?.MaxToken ?? 1024, StopSequences = options?.StopSequence, + ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions, + ResultsPerPrompt = 1, }; - - var chatService = _kernel.GetRequiredService(); - var response = chatService.GetStreamingChatMessageContentsAsync(chatHistory, option, _kernel, cancellationToken); - - return ProcessMessage(response); } private async IAsyncEnumerable ProcessMessage(IAsyncEnumerable response) { - string? 
text = null; await foreach (var content in response) { - if (content is OpenAIStreamingChatMessageContent oaiStreamingChatContent && oaiStreamingChatContent.Content is string chunk) - { - text += chunk; - yield return new Message(Role.Assistant, text, this.Name); - } - else + if (content.ChoiceIndex > 0) { - throw new InvalidOperationException("Unsupported return type"); + throw new InvalidOperationException("Only one choice is supported in streaming response"); } - } - if (text is not null) - { - yield return new Message(Role.Assistant, text, this.Name); + yield return new MessageEnvelope(content, from: this.Name); } } private IEnumerable ProcessMessage(IEnumerable messages) { - return messages.SelectMany(m => + return messages.Select(m => m switch { - if (m is IMessage chatMessageContent) - { - return [chatMessageContent.Content]; - } - if (m.From == this.Name) - { - return ProcessMessageForSelf(m); - } - else - { - return ProcessMessageForOthers(m); - } + IMessage cmc => cmc.Content, + _ => throw new ArgumentException("Invalid message type") }); } - - private IEnumerable ProcessMessageForSelf(IMessage message) - { - return message switch - { - TextMessage textMessage => ProcessMessageForSelf(textMessage), - MultiModalMessage multiModalMessage => ProcessMessageForSelf(multiModalMessage), - Message m => ProcessMessageForSelf(m), - _ => throw new System.NotImplementedException(), - }; - } - - private IEnumerable ProcessMessageForOthers(IMessage message) - { - return message switch - { - TextMessage textMessage => ProcessMessageForOthers(textMessage), - MultiModalMessage multiModalMessage => ProcessMessageForOthers(multiModalMessage), - Message m => ProcessMessageForOthers(m), - _ => throw new System.NotImplementedException(), - }; - } - - private IEnumerable ProcessMessageForSelf(TextMessage message) - { - if (message.Role == Role.System) - { - return [new ChatMessageContent(AuthorRole.System, message.Content)]; - } - else - { - return [new 
ChatMessageContent(AuthorRole.Assistant, message.Content)]; - } - } - - - private IEnumerable ProcessMessageForOthers(TextMessage message) - { - if (message.Role == Role.System) - { - return [new ChatMessageContent(AuthorRole.System, message.Content)]; - } - else - { - return [new ChatMessageContent(AuthorRole.User, message.Content)]; - } - } - - private IEnumerable ProcessMessageForSelf(MultiModalMessage message) - { - throw new System.InvalidOperationException("MultiModalMessage is not supported in the semantic kernel if it's from self."); - } - - private IEnumerable ProcessMessageForOthers(MultiModalMessage message) - { - var collections = new ChatMessageContentItemCollection(); - foreach (var item in message.Content) - { - if (item is TextMessage textContent) - { - collections.Add(new TextContent(textContent.Content)); - } - else if (item is ImageMessage imageContent) - { - collections.Add(new ImageContent(new Uri(imageContent.Url))); - } - else - { - throw new InvalidOperationException($"Unsupported message type: {item.GetType().Name}"); - } - } - return [new ChatMessageContent(AuthorRole.User, collections)]; - } - - - private IEnumerable ProcessMessageForSelf(Message message) - { - if (message.Role == Role.System) - { - return [new ChatMessageContent(AuthorRole.System, message.Content)]; - } - else if (message.Content is string && message.FunctionName is null && message.FunctionArguments is null) - { - return [new ChatMessageContent(AuthorRole.Assistant, message.Content)]; - } - else if (message.Content is null && message.FunctionName is not null && message.FunctionArguments is not null) - { - throw new System.InvalidOperationException("Function call is not supported in the semantic kernel if it's from self."); - } - else - { - throw new System.InvalidOperationException("Unsupported message type"); - } - } - - private IEnumerable ProcessMessageForOthers(Message message) - { - if (message.Role == Role.System) - { - return [new 
ChatMessageContent(AuthorRole.System, message.Content)]; - } - else if (message.Content is string && message.FunctionName is null && message.FunctionArguments is null) - { - return [new ChatMessageContent(AuthorRole.User, message.Content)]; - } - else if (message.Content is null && message.FunctionName is not null && message.FunctionArguments is not null) - { - throw new System.InvalidOperationException("Function call is not supported in the semantic kernel if it's from others."); - } - else - { - throw new System.InvalidOperationException("Unsupported message type"); - } - } } diff --git a/dotnet/test/AutoGen.Tests/OpenAIMessageTests.cs b/dotnet/test/AutoGen.Tests/OpenAIMessageTests.cs index 2e784296fe0..f2e30fe4da1 100644 --- a/dotnet/test/AutoGen.Tests/OpenAIMessageTests.cs +++ b/dotnet/test/AutoGen.Tests/OpenAIMessageTests.cs @@ -42,7 +42,7 @@ public void BasicMessageTest() FunctionArguments = "functionArguments", }, new ImageMessage(Role.User, "https://example.com/image.png", "user"), - new MultiModalMessage( + new MultiModalMessage(Role.Assistant, [ new TextMessage(Role.User, "Hello", "user"), new ImageMessage(Role.User, "https://example.com/image.png", "user"), @@ -63,7 +63,7 @@ public void BasicMessageTest() message1: new ToolCallMessage("test", "test", "assistant"), message2: new ToolCallResultMessage("result", "test", "test", "assistant"), "assistant"), ]; - var openaiMessageConnectorMiddleware = new OpenAIMessageConnector(); + var openaiMessageConnectorMiddleware = new ChatRequestMessageConnector(); var agent = new EchoAgent("assistant"); var oaiMessages = messages.Select(m => (m, openaiMessageConnectorMiddleware.ProcessIncomingMessages(agent, [m]))); @@ -74,7 +74,7 @@ public void BasicMessageTest() public void ToOpenAIChatRequestMessageTest() { var agent = new EchoAgent("assistant"); - var middleware = new OpenAIMessageConnector(); + var middleware = new ChatRequestMessageConnector(); // user message IMessage message = new TextMessage(Role.User, 
"Hello", "user"); @@ -253,7 +253,7 @@ public void ToOpenAIChatRequestMessageTest() public void ToOpenAIChatRequestMessageShortCircuitTest() { var agent = new EchoAgent("assistant"); - var middleware = new OpenAIMessageConnector(); + var middleware = new ChatRequestMessageConnector(); ChatRequestMessage[] messages = [ new ChatRequestUserMessage("Hello"), diff --git a/dotnet/test/AutoGen.Tests/SingleAgentTest.cs b/dotnet/test/AutoGen.Tests/SingleAgentTest.cs index e1bab5738e4..5c8aae38c46 100644 --- a/dotnet/test/AutoGen.Tests/SingleAgentTest.cs +++ b/dotnet/test/AutoGen.Tests/SingleAgentTest.cs @@ -71,7 +71,7 @@ public async Task GPTAgentVisionTestAsync() var oaiMessage = new ChatRequestUserMessage( new ChatMessageTextContentItem("which label has the highest inference cost"), new ChatMessageImageContentItem(new Uri(@"https://raw.githubusercontent.com/microsoft/autogen/main/website/blog/2023-04-21-LLM-tuning-math/img/level2algebra.png"))); - var multiModalMessage = new MultiModalMessage( + var multiModalMessage = new MultiModalMessage(Role.User, [ new TextMessage(Role.User, "which label has the highest inference cost", from: "user"), new ImageMessage(Role.User, @"https://raw.githubusercontent.com/microsoft/autogen/main/website/blog/2023-04-21-LLM-tuning-math/img/level2algebra.png", from: "user"), From d301ea322ddf8ac66131b3df870c8014488c0a5b Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Sat, 2 Mar 2024 12:58:14 -0800 Subject: [PATCH 22/27] update --- .../Example10_SemanticKernel.cs | 4 +- .../Agent/MiddlewareStreamingAgent.cs | 76 +++++++++++++++--- .../Extension/MiddlewareExtension.cs | 78 ++++++++++++++++++- 3 files changed, 141 insertions(+), 17 deletions(-) diff --git a/dotnet/sample/AutoGen.BasicSamples/Example10_SemanticKernel.cs b/dotnet/sample/AutoGen.BasicSamples/Example10_SemanticKernel.cs index eea6dc62298..48af26c3535 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Example10_SemanticKernel.cs +++ 
b/dotnet/sample/AutoGen.BasicSamples/Example10_SemanticKernel.cs @@ -64,9 +64,11 @@ public static async Task RunAsync() // To support more AutoGen bulit-in IMessage, register skAgent with ChatMessageContentConnector var connector = new ChatMessageContentConnector(); var skAgentWithMiddlewares = skAgent - .RegisterMiddleware(connector) + .RegisterMiddlewareToStreamingAgent(connector) + .RegisterStreamingMiddleware(connector) .RegisterPrintFormatMessageHook(); + // Now the skAgentWithMiddlewares supports more IMessage types like TextMessage, ImageMessage or MultiModalMessage // It also register a print format message hook to print the message in a human readable format to the console await skAgent.SendAsync(chatMessageContent); diff --git a/dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs b/dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs index 60d2b2638b2..58ce77aed3a 100644 --- a/dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs +++ b/dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs @@ -11,12 +11,22 @@ namespace AutoGen.Core; public class MiddlewareStreamingAgent : IStreamingAgent { private readonly IStreamingAgent _agent; - private readonly List _middlewares = new(); - - public MiddlewareStreamingAgent(IStreamingAgent agent, string? name = null, IEnumerable? middlewares = null) + private readonly List _streamingMiddlewares = new(); + private readonly List _middlewares = new(); + + public MiddlewareStreamingAgent( + IStreamingAgent agent, + string? name = null, + IEnumerable? streamingMiddlewares = null, + IEnumerable? middlewares = null) { _agent = agent; Name = name ?? agent.Name; + if (streamingMiddlewares != null) + { + _streamingMiddlewares.AddRange(streamingMiddlewares); + } + if (middlewares != null) { _middlewares.AddRange(middlewares); @@ -30,20 +40,31 @@ public MiddlewareStreamingAgent(IStreamingAgent agent, string? name = null, IEnu /// public IStreamingAgent Agent => _agent; + /// + /// Get the streaming middlewares. 
+ /// + public IEnumerable StreamingMiddlewares => _streamingMiddlewares; + /// /// Get the middlewares. /// - public IEnumerable Middlewares => _middlewares; + public IEnumerable Middlewares => _middlewares; public async Task GenerateReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) { - return await _agent.GenerateReplyAsync(messages, options, cancellationToken); + var agent = _agent; + foreach (var middleware in _middlewares) + { + agent = new DelegateStreamingAgent(middleware, agent); + } + + return await agent.GenerateReplyAsync(messages, options, cancellationToken); } public Task> GenerateStreamingReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) { var agent = _agent; - foreach (var middleware in _middlewares) + foreach (var middleware in _streamingMiddlewares) { agent = new DelegateStreamingAgent(middleware, agent); } @@ -51,24 +72,44 @@ public Task> GenerateStreamingReplyAsync(IEn return agent.GenerateStreamingReplyAsync(messages, options, cancellationToken); } - public void Use(IStreamingMiddleware middleware) + public void UseStreaming(IStreamingMiddleware middleware) + { + _streamingMiddlewares.Add(middleware); + } + + public void UseStreaming(Func>> func, string? middlewareName = null) + { + _streamingMiddlewares.Add(new DelegateStreamingMiddleware(middlewareName, new DelegateStreamingMiddleware.MiddlewareDelegate(func))); + } + + public void Use(IMiddleware middleware) { _middlewares.Add(middleware); } - public void Use(Func>> func, string? middlewareName = null) + public void Use(Func, GenerateReplyOptions?, IAgent, CancellationToken, Task> func, string? 
middlewareName = null) { - _middlewares.Add(new DelegateStreamingMiddleware(middlewareName, new DelegateStreamingMiddleware.MiddlewareDelegate(func))); + _middlewares.Add(new DelegateMiddleware(middlewareName, async (context, agent, cancellationToken) => + { + return await func(context.Messages, context.Options, agent, cancellationToken); + })); } private class DelegateStreamingAgent : IStreamingAgent { - private IStreamingMiddleware middleware; + private IStreamingMiddleware? streamingMiddleware; + private IMiddleware? middleware; private IStreamingAgent innerAgent; public string Name => innerAgent.Name; public DelegateStreamingAgent(IStreamingMiddleware middleware, IStreamingAgent next) + { + this.streamingMiddleware = middleware; + this.innerAgent = next; + } + + public DelegateStreamingAgent(IMiddleware middleware, IStreamingAgent next) { this.middleware = middleware; this.innerAgent = next; @@ -76,13 +117,24 @@ public DelegateStreamingAgent(IStreamingMiddleware middleware, IStreamingAgent n public async Task GenerateReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) { - return await innerAgent.GenerateReplyAsync(messages, options, cancellationToken); + if (middleware is null) + { + return await innerAgent.GenerateReplyAsync(messages, options, cancellationToken); + } + + var context = new MiddlewareContext(messages, options); + return await middleware.InvokeAsync(context, innerAgent, cancellationToken); } public Task> GenerateStreamingReplyAsync(IEnumerable messages, GenerateReplyOptions? 
options = null, CancellationToken cancellationToken = default) { + if (streamingMiddleware is null) + { + return innerAgent.GenerateStreamingReplyAsync(messages, options, cancellationToken); + } + var context = new MiddlewareContext(messages, options); - return middleware.InvokeAsync(context, innerAgent, cancellationToken); + return streamingMiddleware.InvokeAsync(context, innerAgent, cancellationToken); } } } diff --git a/dotnet/src/AutoGen.Core/Extension/MiddlewareExtension.cs b/dotnet/src/AutoGen.Core/Extension/MiddlewareExtension.cs index 5f951122a08..4746ce22268 100644 --- a/dotnet/src/AutoGen.Core/Extension/MiddlewareExtension.cs +++ b/dotnet/src/AutoGen.Core/Extension/MiddlewareExtension.cs @@ -51,6 +51,16 @@ public static MiddlewareAgent RegisterPrintFormatMessageHook(thi return middlewareAgent; } + public static MiddlewareAgent RegisterPrintFormatMessageHook(this MiddlewareAgent agent) + where TAgent : IAgent + { + var middleware = new PrintMessageMiddleware(); + var middlewareAgent = new MiddlewareAgent(agent); + middlewareAgent.Use(middleware); + + return middlewareAgent; + } + /// /// Register a post process hook to an agent. The hook will be called before the agent return the reply and after the agent generate the reply. /// This is useful when you want to customize arbitrary behavior before the agent return the reply. @@ -108,6 +118,8 @@ public static MiddlewareAgent RegisterPrintFormatMessageHook(thi return middlewareAgent; } + + /// /// Register a middleware to an existing agent and return a new agent with the middleware. 
/// @@ -165,7 +177,7 @@ public static MiddlewareAgent RegisterPrintFormatMessageHook(thi where TAgent : IStreamingAgent { var middlewareAgent = new MiddlewareStreamingAgent(agent); - middlewareAgent.Use(middleware); + middlewareAgent.UseStreaming(middleware); return middlewareAgent; } @@ -179,7 +191,7 @@ public static MiddlewareAgent RegisterPrintFormatMessageHook(thi where TAgent : IStreamingAgent { var copyAgent = new MiddlewareStreamingAgent(agent); - copyAgent.Use(middleware); + copyAgent.UseStreaming(middleware); return copyAgent; } @@ -195,13 +207,13 @@ public static MiddlewareAgent RegisterPrintFormatMessageHook(thi where TAgent : IStreamingAgent { var middlewareAgent = new MiddlewareStreamingAgent(agent); - middlewareAgent.Use(func, middlewareName); + middlewareAgent.UseStreaming(func, middlewareName); return middlewareAgent; } /// - /// Register a middleware to an existing agent and return a new agent with the middleware. + /// Register a streaming middleware to an existing agent and return a new agent with the middleware. /// public static MiddlewareStreamingAgent RegisterStreamingMiddleware( this MiddlewareStreamingAgent agent, @@ -210,8 +222,66 @@ public static MiddlewareAgent RegisterPrintFormatMessageHook(thi where TAgent : IStreamingAgent { var copyAgent = new MiddlewareStreamingAgent(agent); + copyAgent.UseStreaming(func, middlewareName); + + return copyAgent; + } + + /// + /// Register a middleware to an existing streaming agent and return a new agent with the middleware. + /// + public static MiddlewareStreamingAgent RegisterMiddlewareToStreamingAgent( + this TStreamingAgent streamingAgent, + Func, GenerateReplyOptions?, IAgent, CancellationToken, Task> func, + string? 
middlewareName = null) + where TStreamingAgent : IStreamingAgent + { + var middlewareAgent = new MiddlewareStreamingAgent(streamingAgent); + middlewareAgent.Use(func, middlewareName); + + return middlewareAgent; + } + + /// + /// Register a middleware to an existing streaming agent and return a new agent with the middleware. + /// + public static MiddlewareStreamingAgent RegisterMiddlewareToStreamingAgent( + this TStreamingAgent streamingAgent, + IMiddleware middleware) + where TStreamingAgent : IStreamingAgent + { + var middlewareAgent = new MiddlewareStreamingAgent(streamingAgent); + middlewareAgent.Use(middleware); + + return middlewareAgent; + } + + /// + /// Register a middleware to an existing streaming agent and return a new agent with the middleware. + /// + public static MiddlewareStreamingAgent RegisterMiddleware( + this MiddlewareStreamingAgent streamingAgent, + Func, GenerateReplyOptions?, IAgent, CancellationToken, Task> func, + string? middlewareName = null) + where TStreamingAgent : IStreamingAgent + { + var copyAgent = new MiddlewareStreamingAgent(streamingAgent); copyAgent.Use(func, middlewareName); return copyAgent; } + + /// + /// Register a middleware to an existing streaming agent and return a new agent with the middleware. 
+ /// + public static MiddlewareStreamingAgent RegisterMiddleware( + this MiddlewareStreamingAgent streamingAgent, + IMiddleware middleware) + where TStreamingAgent : IStreamingAgent + { + var copyAgent = new MiddlewareStreamingAgent(streamingAgent); + copyAgent.Use(middleware); + + return copyAgent; + } } From f9ff7e6d3166707955f418c8c949072822595ca5 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Sat, 2 Mar 2024 22:12:13 -0800 Subject: [PATCH 23/27] add test --- .../Example10_SemanticKernel.cs | 15 +- .../src/AutoGen.Core/Agent/MiddlewareAgent.cs | 9 ++ .../Extension/MiddlewareExtension.cs | 147 +----------------- .../PrintMessageMiddlewareExtension.cs | 40 +++++ .../Extension/StreamingMiddlewareExtension.cs | 100 ++++++++++++ dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs | 6 +- ...penAIClientAgent.cs => OpenAIChatAgent.cs} | 8 +- .../Middleware/ChatMessageContentConnector.cs | 11 +- dotnet/test/AutoGen.Tests/BasicSampleTest.cs | 13 +- .../test/AutoGen.Tests/MiddlewareAgentTest.cs | 3 + .../test/AutoGen.Tests/OpenAIMessageTests.cs | 1 + .../AutoGen.Tests/SemanticKernelAgentTest.cs | 133 ++++++++++++++++ 12 files changed, 320 insertions(+), 166 deletions(-) create mode 100644 dotnet/src/AutoGen.Core/Extension/PrintMessageMiddlewareExtension.cs create mode 100644 dotnet/src/AutoGen.Core/Extension/StreamingMiddlewareExtension.cs rename dotnet/src/AutoGen.OpenAI/Agent/{OpenAIClientAgent.cs => OpenAIChatAgent.cs} (95%) create mode 100644 dotnet/test/AutoGen.Tests/SemanticKernelAgentTest.cs diff --git a/dotnet/sample/AutoGen.BasicSamples/Example10_SemanticKernel.cs b/dotnet/sample/AutoGen.BasicSamples/Example10_SemanticKernel.cs index 48af26c3535..beda3f71716 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Example10_SemanticKernel.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Example10_SemanticKernel.cs @@ -61,24 +61,23 @@ public static async Task RunAsync() reply.Should().BeOfType>(); Console.WriteLine((reply as IMessage).Content.Items[0].As().Text); - // To support 
more AutoGen bulit-in IMessage, register skAgent with ChatMessageContentConnector + // To support more AutoGen built-in IMessage, register skAgent with ChatMessageContentConnector var connector = new ChatMessageContentConnector(); - var skAgentWithMiddlewares = skAgent - .RegisterMiddlewareToStreamingAgent(connector) + var skAgentWithMiddleware = skAgent .RegisterStreamingMiddleware(connector) + .RegisterMiddleware(connector) .RegisterPrintFormatMessageHook(); - - // Now the skAgentWithMiddlewares supports more IMessage types like TextMessage, ImageMessage or MultiModalMessage + // Now the skAgentWithMiddleware supports more IMessage types like TextMessage, ImageMessage or MultiModalMessage // It also register a print format message hook to print the message in a human readable format to the console await skAgent.SendAsync(chatMessageContent); - await skAgentWithMiddlewares.SendAsync(new TextMessage(Role.User, "Toggle the light")); + await skAgentWithMiddleware.SendAsync(new TextMessage(Role.User, "Toggle the light")); // The more message type an agent support, the more flexible it is to be used in different scenarios - // For example, since the TextMessage is supported, the skAgentWithMiddlewares can be used with user proxy. + // For example, since the TextMessage is supported, the skAgentWithMiddleware can be used with user proxy. 
var userProxy = new UserProxyAgent("user"); - await skAgentWithMiddlewares.InitiateChatAsync(userProxy, "how can I help you today"); + await skAgentWithMiddleware.InitiateChatAsync(userProxy, "how can I help you today"); } } diff --git a/dotnet/src/AutoGen.Core/Agent/MiddlewareAgent.cs b/dotnet/src/AutoGen.Core/Agent/MiddlewareAgent.cs index 71c8bb7e514..0f06a38b828 100644 --- a/dotnet/src/AutoGen.Core/Agent/MiddlewareAgent.cs +++ b/dotnet/src/AutoGen.Core/Agent/MiddlewareAgent.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Threading; using System.Threading.Tasks; @@ -86,6 +87,14 @@ public void Use(IMiddleware middleware) this.middlewares.Add(middleware); } + public override string ToString() + { + var names = this.Middlewares.Select(m => m.Name ?? "[Unknown middleware]"); + var namesPlusAgentName = names.Append(this.Name); + + return namesPlusAgentName.Aggregate((a, b) => $"{a} -> {b}"); + } + private class DelegateAgent : IAgent { private readonly IAgent innerAgent; diff --git a/dotnet/src/AutoGen.Core/Extension/MiddlewareExtension.cs b/dotnet/src/AutoGen.Core/Extension/MiddlewareExtension.cs index 4746ce22268..50f82dd87fb 100644 --- a/dotnet/src/AutoGen.Core/Extension/MiddlewareExtension.cs +++ b/dotnet/src/AutoGen.Core/Extension/MiddlewareExtension.cs @@ -38,34 +38,11 @@ public static class MiddlewareExtension }); } - /// - /// Print formatted message to console. 
- /// - public static MiddlewareAgent RegisterPrintFormatMessageHook(this TAgent agent) - where TAgent : IAgent - { - var middleware = new PrintMessageMiddleware(); - var middlewareAgent = new MiddlewareAgent(agent); - middlewareAgent.Use(middleware); - - return middlewareAgent; - } - - public static MiddlewareAgent RegisterPrintFormatMessageHook(this MiddlewareAgent agent) - where TAgent : IAgent - { - var middleware = new PrintMessageMiddleware(); - var middlewareAgent = new MiddlewareAgent(agent); - middlewareAgent.Use(middleware); - - return middlewareAgent; - } - /// /// Register a post process hook to an agent. The hook will be called before the agent return the reply and after the agent generate the reply. /// This is useful when you want to customize arbitrary behavior before the agent return the reply. /// - /// One example is , which print the formatted message to console before the agent return the reply. + /// One example is , which print the formatted message to console before the agent return the reply. /// /// throw when agent name is null. public static MiddlewareAgent RegisterPostProcess( @@ -85,9 +62,10 @@ public static MiddlewareAgent RegisterPrintFormatMessageHook(thi /// Register a pre process hook to an agent. The hook will be called before the agent generate the reply. This is useful when you want to modify the conversation history before the agent generate the reply. /// /// throw when agent name is null. - public static IAgent RegisterPreProcess( - this IAgent agent, + public static MiddlewareAgent RegisterPreProcess( + this TAgent agent, Func, CancellationToken, Task>> preprocessFunc) + where TAgent : IAgent { return agent.RegisterMiddleware(async (messages, options, agent, ct) => { @@ -167,121 +145,4 @@ public static MiddlewareAgent RegisterPrintFormatMessageHook(thi return copyAgent; } - - /// - /// Register a middleware to an existing agent and return a new agent with the middleware. 
- /// - public static MiddlewareStreamingAgent RegisterStreamingMiddleware( - this TAgent agent, - IStreamingMiddleware middleware) - where TAgent : IStreamingAgent - { - var middlewareAgent = new MiddlewareStreamingAgent(agent); - middlewareAgent.UseStreaming(middleware); - - return middlewareAgent; - } - - /// - /// Register a middleware to an existing agent and return a new agent with the middleware. - /// - public static MiddlewareStreamingAgent RegisterStreamingMiddleware( - this MiddlewareStreamingAgent agent, - IStreamingMiddleware middleware) - where TAgent : IStreamingAgent - { - var copyAgent = new MiddlewareStreamingAgent(agent); - copyAgent.UseStreaming(middleware); - - return copyAgent; - } - - - /// - /// Register a middleware to an existing agent and return a new agent with the middleware. - /// - public static MiddlewareStreamingAgent RegisterStreamingMiddleware( - this TAgent agent, - Func>> func, - string? middlewareName = null) - where TAgent : IStreamingAgent - { - var middlewareAgent = new MiddlewareStreamingAgent(agent); - middlewareAgent.UseStreaming(func, middlewareName); - - return middlewareAgent; - } - - /// - /// Register a streaming middleware to an existing agent and return a new agent with the middleware. - /// - public static MiddlewareStreamingAgent RegisterStreamingMiddleware( - this MiddlewareStreamingAgent agent, - Func>> func, - string? middlewareName = null) - where TAgent : IStreamingAgent - { - var copyAgent = new MiddlewareStreamingAgent(agent); - copyAgent.UseStreaming(func, middlewareName); - - return copyAgent; - } - - /// - /// Register a middleware to an existing streaming agent and return a new agent with the middleware. - /// - public static MiddlewareStreamingAgent RegisterMiddlewareToStreamingAgent( - this TStreamingAgent streamingAgent, - Func, GenerateReplyOptions?, IAgent, CancellationToken, Task> func, - string? 
middlewareName = null) - where TStreamingAgent : IStreamingAgent - { - var middlewareAgent = new MiddlewareStreamingAgent(streamingAgent); - middlewareAgent.Use(func, middlewareName); - - return middlewareAgent; - } - - /// - /// Register a middleware to an existing streaming agent and return a new agent with the middleware. - /// - public static MiddlewareStreamingAgent RegisterMiddlewareToStreamingAgent( - this TStreamingAgent streamingAgent, - IMiddleware middleware) - where TStreamingAgent : IStreamingAgent - { - var middlewareAgent = new MiddlewareStreamingAgent(streamingAgent); - middlewareAgent.Use(middleware); - - return middlewareAgent; - } - - /// - /// Register a middleware to an existing streaming agent and return a new agent with the middleware. - /// - public static MiddlewareStreamingAgent RegisterMiddleware( - this MiddlewareStreamingAgent streamingAgent, - Func, GenerateReplyOptions?, IAgent, CancellationToken, Task> func, - string? middlewareName = null) - where TStreamingAgent : IStreamingAgent - { - var copyAgent = new MiddlewareStreamingAgent(streamingAgent); - copyAgent.Use(func, middlewareName); - - return copyAgent; - } - - /// - /// Register a middleware to an existing streaming agent and return a new agent with the middleware. - /// - public static MiddlewareStreamingAgent RegisterMiddleware( - this MiddlewareStreamingAgent streamingAgent, - IMiddleware middleware) - where TStreamingAgent : IStreamingAgent - { - var copyAgent = new MiddlewareStreamingAgent(streamingAgent); - copyAgent.Use(middleware); - - return copyAgent; - } } diff --git a/dotnet/src/AutoGen.Core/Extension/PrintMessageMiddlewareExtension.cs b/dotnet/src/AutoGen.Core/Extension/PrintMessageMiddlewareExtension.cs new file mode 100644 index 00000000000..e6d64e168f0 --- /dev/null +++ b/dotnet/src/AutoGen.Core/Extension/PrintMessageMiddlewareExtension.cs @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// PrintMessageMiddlewareExtension.cs + +namespace AutoGen.Core; + +public static class PrintMessageMiddlewareExtension +{ + /// + /// Print formatted message to console. + /// + public static MiddlewareAgent RegisterPrintFormatMessageHook(this TAgent agent) + where TAgent : IAgent + { + var middleware = new PrintMessageMiddleware(); + var middlewareAgent = new MiddlewareAgent(agent); + middlewareAgent.Use(middleware); + + return middlewareAgent; + } + + public static MiddlewareAgent RegisterPrintFormatMessageHook(this MiddlewareAgent agent) + where TAgent : IAgent + { + var middleware = new PrintMessageMiddleware(); + var middlewareAgent = new MiddlewareAgent(agent); + middlewareAgent.Use(middleware); + + return middlewareAgent; + } + + public static MiddlewareStreamingAgent RegisterPrintFormatMessageHook(this MiddlewareStreamingAgent agent) + where TAgent : IStreamingAgent + { + var middleware = new PrintMessageMiddleware(); + var middlewareAgent = new MiddlewareStreamingAgent(agent); + middlewareAgent.Use(middleware); + + return middlewareAgent; + } +} diff --git a/dotnet/src/AutoGen.Core/Extension/StreamingMiddlewareExtension.cs b/dotnet/src/AutoGen.Core/Extension/StreamingMiddlewareExtension.cs new file mode 100644 index 00000000000..b2739bb27a1 --- /dev/null +++ b/dotnet/src/AutoGen.Core/Extension/StreamingMiddlewareExtension.cs @@ -0,0 +1,100 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// StreamingMiddlewareExtension.cs + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +namespace AutoGen.Core; + +public static class StreamingMiddlewareExtension +{ + /// + /// Register a middleware to an existing agent and return a new agent with the middleware. 
+ /// + public static MiddlewareStreamingAgent RegisterStreamingMiddleware( + this TStreamingAgent agent, + IStreamingMiddleware middleware) + where TStreamingAgent : IStreamingAgent + { + var middlewareAgent = new MiddlewareStreamingAgent(agent); + middlewareAgent.UseStreaming(middleware); + + return middlewareAgent; + } + + /// + /// Register a middleware to an existing agent and return a new agent with the middleware. + /// + public static MiddlewareStreamingAgent RegisterStreamingMiddleware( + this MiddlewareStreamingAgent agent, + IStreamingMiddleware middleware) + where TAgent : IStreamingAgent + { + var copyAgent = new MiddlewareStreamingAgent(agent); + copyAgent.UseStreaming(middleware); + + return copyAgent; + } + + + /// + /// Register a middleware to an existing agent and return a new agent with the middleware. + /// + public static MiddlewareStreamingAgent RegisterStreamingMiddleware( + this TAgent agent, + Func>> func, + string? middlewareName = null) + where TAgent : IStreamingAgent + { + var middlewareAgent = new MiddlewareStreamingAgent(agent); + middlewareAgent.UseStreaming(func, middlewareName); + + return middlewareAgent; + } + + /// + /// Register a streaming middleware to an existing agent and return a new agent with the middleware. + /// + public static MiddlewareStreamingAgent RegisterStreamingMiddleware( + this MiddlewareStreamingAgent agent, + Func>> func, + string? middlewareName = null) + where TAgent : IStreamingAgent + { + var copyAgent = new MiddlewareStreamingAgent(agent); + copyAgent.UseStreaming(func, middlewareName); + + return copyAgent; + } + + /// + /// Register a middleware to an existing streaming agent and return a new agent with the middleware. + /// + public static MiddlewareStreamingAgent RegisterMiddleware( + this MiddlewareStreamingAgent streamingAgent, + Func, GenerateReplyOptions?, IAgent, CancellationToken, Task> func, + string? 
middlewareName = null) + where TStreamingAgent : IStreamingAgent + { + var copyAgent = new MiddlewareStreamingAgent(streamingAgent); + copyAgent.Use(func, middlewareName); + + return copyAgent; + } + + /// + /// Register a middleware to an existing streaming agent and return a new agent with the middleware. + /// + public static MiddlewareStreamingAgent RegisterMiddleware( + this MiddlewareStreamingAgent streamingAgent, + IMiddleware middleware) + where TStreamingAgent : IStreamingAgent + { + var copyAgent = new MiddlewareStreamingAgent(streamingAgent); + copyAgent.Use(middleware); + + return copyAgent; + } +} diff --git a/dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs b/dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs index 0011c74f8da..e1b5546aaab 100644 --- a/dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs +++ b/dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs @@ -32,7 +32,7 @@ public class GPTAgent : IStreamingAgent private readonly IDictionary>>? functionMap; private readonly OpenAIClient openAIClient; private readonly string? 
modelName; - private readonly OpenAIClientAgent _innerAgent; + private readonly OpenAIChatAgent _innerAgent; public GPTAgent( string name, @@ -57,7 +57,7 @@ public class GPTAgent : IStreamingAgent _ => throw new ArgumentException($"Unsupported config type {config.GetType()}"), }; - _innerAgent = new OpenAIClientAgent(openAIClient, name, systemMessage, modelName, temperature, maxTokens, functions); + _innerAgent = new OpenAIChatAgent(openAIClient, name, systemMessage, modelName, temperature, maxTokens, functions); Name = name; this.functionMap = functionMap; } @@ -76,7 +76,7 @@ public class GPTAgent : IStreamingAgent this.modelName = modelName; Name = name; this.functionMap = functionMap; - _innerAgent = new OpenAIClientAgent(openAIClient, name, systemMessage, modelName, temperature, maxTokens, functions); + _innerAgent = new OpenAIChatAgent(openAIClient, name, systemMessage, modelName, temperature, maxTokens, functions); } public string Name { get; } diff --git a/dotnet/src/AutoGen.OpenAI/Agent/OpenAIClientAgent.cs b/dotnet/src/AutoGen.OpenAI/Agent/OpenAIChatAgent.cs similarity index 95% rename from dotnet/src/AutoGen.OpenAI/Agent/OpenAIClientAgent.cs rename to dotnet/src/AutoGen.OpenAI/Agent/OpenAIChatAgent.cs index ac5c8691fc8..6ac88f939a5 100644 --- a/dotnet/src/AutoGen.OpenAI/Agent/OpenAIClientAgent.cs +++ b/dotnet/src/AutoGen.OpenAI/Agent/OpenAIChatAgent.cs @@ -15,13 +15,13 @@ namespace AutoGen.OpenAI; /// /// OpenAI client agent. This agent is a thin wrapper around to provide a simple interface for chat completions. /// To better work with other agents, it's recommended to use which supports more message types and have a better compatibility with other agents. -/// supports the following message types: +/// supports the following message types: /// /// /// where T is : chat request message. /// /// -/// returns the following message types: +/// returns the following message types: /// /// /// where T is : chat response message. 
@@ -29,7 +29,7 @@ namespace AutoGen.OpenAI; /// /// /// -public class OpenAIClientAgent : IStreamingAgent +public class OpenAIChatAgent : IStreamingAgent { private readonly OpenAIClient openAIClient; private readonly string modelName; @@ -38,7 +38,7 @@ public class OpenAIClientAgent : IStreamingAgent private readonly IEnumerable? _functions; private readonly string _systemMessage; - public OpenAIClientAgent( + public OpenAIChatAgent( OpenAIClient openAIClient, string name, string systemMessage, diff --git a/dotnet/src/AutoGen.SemanticKernel/Middleware/ChatMessageContentConnector.cs b/dotnet/src/AutoGen.SemanticKernel/Middleware/ChatMessageContentConnector.cs index 3d1931b444c..b21d8c1e1da 100644 --- a/dotnet/src/AutoGen.SemanticKernel/Middleware/ChatMessageContentConnector.cs +++ b/dotnet/src/AutoGen.SemanticKernel/Middleware/ChatMessageContentConnector.cs @@ -152,8 +152,9 @@ private IEnumerable ProcessMessageForOthers(IMessage message { TextMessage textMessage => ProcessMessageForOthers(textMessage), MultiModalMessage multiModalMessage => ProcessMessageForOthers(multiModalMessage), + ImageMessage imageMessage => ProcessMessageForOthers(imageMessage), Message m => ProcessMessageForOthers(m), - _ => throw new System.NotImplementedException(), + _ => throw new InvalidOperationException("unsupported message type, only support TextMessage, ImageMessage, MultiModalMessage and Message."), }; } @@ -182,6 +183,14 @@ private IEnumerable ProcessMessageForOthers(TextMessage mess } } + private IEnumerable ProcessMessageForOthers(ImageMessage message) + { + var imageContent = new ImageContent(new Uri(message.Url)); + var collectionItems = new ChatMessageContentItemCollection(); + collectionItems.Add(imageContent); + return [new ChatMessageContent(AuthorRole.User, collectionItems)]; + } + private IEnumerable ProcessMessageForSelf(MultiModalMessage message) { throw new System.InvalidOperationException("MultiModalMessage is not supported in the semantic kernel if it's from 
self."); diff --git a/dotnet/test/AutoGen.Tests/BasicSampleTest.cs b/dotnet/test/AutoGen.Tests/BasicSampleTest.cs index df96c9e8c9a..2cae5f950db 100644 --- a/dotnet/test/AutoGen.Tests/BasicSampleTest.cs +++ b/dotnet/test/AutoGen.Tests/BasicSampleTest.cs @@ -18,32 +18,31 @@ public BasicSampleTest(ITestOutputHelper output) Console.SetOut(new ConsoleWriter(_output)); } - [ApiKeyFact("OPENAI_API_KEY")] + [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT")] public async Task AssistantAgentTestAsync() { await Example01_AssistantAgent.RunAsync(); } - [ApiKeyFact("OPENAI_API_KEY")] + [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT")] public async Task TwoAgentMathClassTestAsync() { await Example02_TwoAgent_MathChat.RunAsync(); } - [ApiKeyFact("OPENAI_API_KEY")] + [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT")] public async Task AgentFunctionCallTestAsync() { - var instance = new Example03_Agent_FunctionCall(); - await instance.RunAsync(); + await Example03_Agent_FunctionCall.RunAsync(); } - [ApiKeyFact("OPENAI_API_KEY")] + [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT")] public async Task DynamicGroupChatGetMLNetPRTestAsync() { await Example04_Dynamic_GroupChat_Coding_Task.RunAsync(); } - [ApiKeyFact("OPENAI_API_KEY")] + [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT")] public async Task DynamicGroupChatCalculateFibonacciAsync() { await Example07_Dynamic_GroupChat_Calculate_Fibonacci.RunAsync(); diff --git a/dotnet/test/AutoGen.Tests/MiddlewareAgentTest.cs b/dotnet/test/AutoGen.Tests/MiddlewareAgentTest.cs index dfb90324c09..9241c9e94f9 100644 --- a/dotnet/test/AutoGen.Tests/MiddlewareAgentTest.cs +++ b/dotnet/test/AutoGen.Tests/MiddlewareAgentTest.cs @@ -69,6 +69,8 @@ public async Task RegisterMiddlewareTestAsync() return await agent.GenerateReplyAsync(messages, options, ct); }); + middlewareAgent.Should().BeOfType>(); + middlewareAgent.Middlewares.Count().Should().Be(1); var reply = await 
middlewareAgent.SendAsync("hello"); reply.GetContent().Should().Be("[middleware 0] hello"); reply = await echoAgent.SendAsync("hello"); @@ -82,6 +84,7 @@ public async Task RegisterMiddlewareTestAsync() return await agent.GenerateReplyAsync(messages, options, ct); }); + middlewareAgent.Middlewares.Count().Should().Be(2); reply = await middlewareAgent.SendAsync("hello"); reply.GetContent().Should().Be("[middleware 0] [middleware 1] hello"); diff --git a/dotnet/test/AutoGen.Tests/OpenAIMessageTests.cs b/dotnet/test/AutoGen.Tests/OpenAIMessageTests.cs index f2e30fe4da1..27d6eed42fc 100644 --- a/dotnet/test/AutoGen.Tests/OpenAIMessageTests.cs +++ b/dotnet/test/AutoGen.Tests/OpenAIMessageTests.cs @@ -109,6 +109,7 @@ public void ToOpenAIChatRequestMessageTest() // text and image message = new MultiModalMessage( + Role.User, [ new TextMessage(Role.User, "Hello", "user"), new ImageMessage(Role.User, "https://example.com/image.png", "user"), diff --git a/dotnet/test/AutoGen.Tests/SemanticKernelAgentTest.cs b/dotnet/test/AutoGen.Tests/SemanticKernelAgentTest.cs new file mode 100644 index 00000000000..6f7abd42e93 --- /dev/null +++ b/dotnet/test/AutoGen.Tests/SemanticKernelAgentTest.cs @@ -0,0 +1,133 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SemanticKernelAgentTest.cs + +using System; +using System.Linq; +using System.Threading.Tasks; +using AutoGen.SemanticKernel; +using AutoGen.SemanticKernel.Middleware; +using FluentAssertions; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; + +namespace AutoGen.Tests; + +public partial class SemanticKernelAgentTest +{ + /// + /// Get the weather for a location. 
+ /// + /// location + /// + [Function] + public async Task GetWeatherAsync(string location) + { + return $"The weather in {location} is sunny."; + } + + [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT")] + public async Task BasicConversationTestAsync() + { + var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable."); + var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable."); + var builder = Kernel.CreateBuilder() + .AddAzureOpenAIChatCompletion("gpt-35-turbo-16k", endpoint, key); + + var kernel = builder.Build(); + + var skAgent = new SemanticKernelAgent(kernel, "assistant"); + + var chatMessageContent = MessageEnvelope.Create(new ChatMessageContent(AuthorRole.Assistant, "Hello")); + var reply = await skAgent.SendAsync(chatMessageContent); + + reply.Should().BeOfType>(); + reply.As>().From.Should().Be("assistant"); + + // test streaming + var streamingReply = await skAgent.GenerateStreamingReplyAsync(new[] { chatMessageContent }); + + await foreach (var streamingMessage in streamingReply) + { + streamingMessage.Should().BeOfType>(); + streamingMessage.As>().From.Should().Be("assistant"); + } + } + + [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT")] + public async Task SemanticKernelChatMessageContentConnectorTestAsync() + { + var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable."); + var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? 
throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable."); + var builder = Kernel.CreateBuilder() + .AddAzureOpenAIChatCompletion("gpt-35-turbo-16k", endpoint, key); + + var kernel = builder.Build(); + + var connector = new ChatMessageContentConnector(); + var skAgent = new SemanticKernelAgent(kernel, "assistant") + .RegisterStreamingMiddleware(connector) + .RegisterMiddleware(connector); + + var messages = new IMessage[] + { + MessageEnvelope.Create(new ChatMessageContent(AuthorRole.Assistant, "Hello")), + new TextMessage(Role.Assistant, "Hello", from: "user"), + new MultiModalMessage(Role.Assistant, + [ + new TextMessage(Role.Assistant, "Hello", from: "user"), + ], + from: "user"), + }; + + foreach (var message in messages) + { + var reply = await skAgent.SendAsync(message); + + reply.Should().BeOfType(); + reply.As().From.Should().Be("assistant"); + } + + // test streaming + foreach (var message in messages) + { + var reply = await skAgent.GenerateStreamingReplyAsync([message]); + + await foreach (var streamingMessage in reply) + { + streamingMessage.Should().BeOfType(); + streamingMessage.As().From.Should().Be("assistant"); + } + } + } + + [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT")] + public async Task SemanticKernelPluginTestAsync() + { + var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable."); + var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable."); + var builder = Kernel.CreateBuilder() + .AddAzureOpenAIChatCompletion("gpt-35-turbo-16k", endpoint, key); + + var parameters = this.GetWeatherAsyncFunctionContract.Parameters!.Select(p => new KernelParameterMetadata(p.Name!) 
+ { + Description = p.Description, + DefaultValue = p.DefaultValue, + IsRequired = p.IsRequired, + ParameterType = p.ParameterType, + }); + var function = KernelFunctionFactory.CreateFromMethod(this.GetWeatherAsync, this.GetWeatherAsyncFunctionContract.Name, this.GetWeatherAsyncFunctionContract.Description, parameters); + builder.Plugins.AddFromFunctions("plugins", [function]); + var kernel = builder.Build(); + + var connector = new ChatMessageContentConnector(); + var skAgent = new SemanticKernelAgent(kernel, "assistant") + .RegisterStreamingMiddleware(connector) + .RegisterMiddleware(connector); + + var question = "What is the weather in Seattle?"; + var reply = await skAgent.SendAsync(question); + + reply.GetContent()!.ToLower().Should().Contain("seattle"); + reply.GetContent()!.ToLower().Should().Contain("sunny"); + } +} From b32f5b5a9076a02e8f94d833556de53f6cdc4be5 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Sat, 2 Mar 2024 22:16:23 -0800 Subject: [PATCH 24/27] fix test --- .../OpenAIMessageTests.BasicMessageTest.approved.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dotnet/test/AutoGen.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt b/dotnet/test/AutoGen.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt index aeb5add2293..2cb58f4d88c 100644 --- a/dotnet/test/AutoGen.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt +++ b/dotnet/test/AutoGen.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.approved.txt @@ -102,7 +102,7 @@ ] }, { - "OriginalMessage": "MultiModalMessage(, user)\n\tTextMessage(user, Hello, user)\n\tImageMessage(user, https://example.com/image.png, user)", + "OriginalMessage": "MultiModalMessage(assistant, user)\n\tTextMessage(user, Hello, user)\n\tImageMessage(user, https://example.com/image.png, user)", "ConvertedMessages": [ { "Role": "user", From 9bb9fd91fe55b7f3b2e0de81b3df4c513cfd64ee Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Sat, 2 Mar 
2024 22:24:22 -0800 Subject: [PATCH 25/27] bump version --- dotnet/eng/MetaInfo.props | 2 +- dotnet/nuget/NUGET.md | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/dotnet/eng/MetaInfo.props b/dotnet/eng/MetaInfo.props index b40b9c69723..28b7add9e20 100644 --- a/dotnet/eng/MetaInfo.props +++ b/dotnet/eng/MetaInfo.props @@ -1,7 +1,7 @@ - 0.0.8 + 0.0.9 AutoGen https://microsoft.github.io/autogen-for-net/ https://github.com/microsoft/autogen diff --git a/dotnet/nuget/NUGET.md b/dotnet/nuget/NUGET.md index 5616f573023..d4b67b6b20b 100644 --- a/dotnet/nuget/NUGET.md +++ b/dotnet/nuget/NUGET.md @@ -1,4 +1,9 @@ ## AutoGen +#### Update on 0.0.9 (2024-03-02) +- Refactor over @AutoGen.Message and introducing `TextMessage`, `ImageMessage`, `MultiModalMessage` and so on. PR [#1676](https://github.com/microsoft/autogen/pull/1676) +- Add `AutoGen.SemanticKernel` to support seamless integration with Semantic Kernel +- Move the agent contract abstraction to `AutoGen.Core` package. The `AutoGen.Core` package provides the abstraction for message type, agent and group chat and doesn't contain dependencies over `Azure.AI.OpenAI` or `Semantic Kernel`. This is useful when you want to leverage AutoGen's abstraction only and want to avoid introducing any other dependencies. 
+- Move `GPTAgent`, `OpenAIChatAgent` and all openai-dependencies to `AutoGen.OpenAI` #### Update on 0.0.8 (2024-02-28) - Fix [#1804](https://github.com/microsoft/autogen/pull/1804) - Streaming support for IAgent [#1656](https://github.com/microsoft/autogen/pull/1656) From cbacdc7bce90d9f1e0b898de2238e17ffebb235a Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Sun, 3 Mar 2024 14:28:58 -0800 Subject: [PATCH 26/27] add openaichat test --- .../Example10_SemanticKernel.cs | 2 +- .../Middleware/FunctionCallMiddleware.cs | 29 +-- dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs | 4 +- .../AutoGen.OpenAI/Agent/OpenAIChatAgent.cs | 2 +- ...s => OpenAIChatRequestMessageConnector.cs} | 6 +- ...anticKernelChatMessageContentConnector.cs} | 6 +- .../SemanticKernelAgent.cs | 2 +- ...MessageTests.BasicMessageTest.received.txt | 219 ++++++++++++++++ .../test/AutoGen.Tests/OpenAIChatAgentTest.cs | 244 ++++++++++++++++++ .../test/AutoGen.Tests/OpenAIMessageTests.cs | 6 +- .../AutoGen.Tests/SemanticKernelAgentTest.cs | 4 +- 11 files changed, 481 insertions(+), 43 deletions(-) rename dotnet/src/AutoGen.OpenAI/Middleware/{ChatRequestMessageConnector.cs => OpenAIChatRequestMessageConnector.cs} (98%) rename dotnet/src/AutoGen.SemanticKernel/Middleware/{ChatMessageContentConnector.cs => SemanticKernelChatMessageContentConnector.cs} (97%) create mode 100644 dotnet/test/AutoGen.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.received.txt create mode 100644 dotnet/test/AutoGen.Tests/OpenAIChatAgentTest.cs diff --git a/dotnet/sample/AutoGen.BasicSamples/Example10_SemanticKernel.cs b/dotnet/sample/AutoGen.BasicSamples/Example10_SemanticKernel.cs index beda3f71716..f5df8ac5fac 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Example10_SemanticKernel.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Example10_SemanticKernel.cs @@ -62,7 +62,7 @@ public static async Task RunAsync() Console.WriteLine((reply as IMessage).Content.Items[0].As().Text); // To support more AutoGen built-in IMessage, 
register skAgent with ChatMessageContentConnector - var connector = new ChatMessageContentConnector(); + var connector = new SemanticKernelChatMessageContentConnector(); var skAgentWithMiddleware = skAgent .RegisterStreamingMiddleware(connector) .RegisterMiddleware(connector) diff --git a/dotnet/src/AutoGen.Core/Middleware/FunctionCallMiddleware.cs b/dotnet/src/AutoGen.Core/Middleware/FunctionCallMiddleware.cs index c8a68de5147..d00151b32a8 100644 --- a/dotnet/src/AutoGen.Core/Middleware/FunctionCallMiddleware.cs +++ b/dotnet/src/AutoGen.Core/Middleware/FunctionCallMiddleware.cs @@ -95,7 +95,7 @@ public Task> InvokeAsync(MiddlewareContext c IStreamingMessage? initMessage = default; await foreach (var message in await agent.GenerateStreamingReplyAsync(context.Messages, options, cancellationToken)) { - if (message is ToolCallMessageUpdate toolCallMessageUpdate) + if (message is ToolCallMessageUpdate toolCallMessageUpdate && this.functionMap != null) { if (initMessage is null) { @@ -110,41 +110,16 @@ await foreach (var message in await agent.GenerateStreamingReplyAsync(context.Me throw new InvalidOperationException("The first message is ToolCallMessage, but the update message is not ToolCallMessageUpdate"); } } - else if (message is TextMessageUpdate textMessageUpdate) - { - if (initMessage is null) - { - initMessage = new TextMessage(textMessageUpdate); - } - else if (initMessage is TextMessage textMessage) - { - textMessage.Update(textMessageUpdate); - } - else - { - throw new InvalidOperationException("The first message is TextMessage, but the update message is not TextMessageUpdate"); - } - } else { - initMessage = message; + yield return message; } - - yield return initMessage; } if (initMessage is ToolCallMessage toolCallMsg) { yield return await this.InvokeToolCallMessagesAfterInvokingAgentAsync(toolCallMsg, agent); } - else if (initMessage is not null) - { - yield return initMessage; - } - else - { - throw new InvalidOperationException("The agent returns no 
message."); - } } private async Task InvokeToolCallMessagesBeforeInvokingAgentAsync(ToolCallMessage toolCallMessage, IAgent agent) diff --git a/dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs b/dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs index e1b5546aaab..4e2c3a9c749 100644 --- a/dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs +++ b/dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs @@ -86,7 +86,7 @@ public class GPTAgent : IStreamingAgent GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) { - var oaiConnectorMiddleware = new ChatRequestMessageConnector(); + var oaiConnectorMiddleware = new OpenAIChatRequestMessageConnector(); var agent = this._innerAgent.RegisterMiddleware(oaiConnectorMiddleware); if (this.functionMap is not null) { @@ -102,7 +102,7 @@ public class GPTAgent : IStreamingAgent GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) { - var oaiConnectorMiddleware = new ChatRequestMessageConnector(); + var oaiConnectorMiddleware = new OpenAIChatRequestMessageConnector(); var agent = this._innerAgent.RegisterStreamingMiddleware(oaiConnectorMiddleware); if (this.functionMap is not null) { diff --git a/dotnet/src/AutoGen.OpenAI/Agent/OpenAIChatAgent.cs b/dotnet/src/AutoGen.OpenAI/Agent/OpenAIChatAgent.cs index 6ac88f939a5..bdde3c085c4 100644 --- a/dotnet/src/AutoGen.OpenAI/Agent/OpenAIChatAgent.cs +++ b/dotnet/src/AutoGen.OpenAI/Agent/OpenAIChatAgent.cs @@ -41,8 +41,8 @@ public class OpenAIChatAgent : IStreamingAgent public OpenAIChatAgent( OpenAIClient openAIClient, string name, - string systemMessage, string modelName, + string systemMessage = "You are a helpful AI assistant", float temperature = 0.7f, int maxTokens = 1024, IEnumerable? 
functions = null) diff --git a/dotnet/src/AutoGen.OpenAI/Middleware/ChatRequestMessageConnector.cs b/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs similarity index 98% rename from dotnet/src/AutoGen.OpenAI/Middleware/ChatRequestMessageConnector.cs rename to dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs index 816bcd8b91f..23ddb3f20ad 100644 --- a/dotnet/src/AutoGen.OpenAI/Middleware/ChatRequestMessageConnector.cs +++ b/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs @@ -23,16 +23,16 @@ namespace AutoGen.OpenAI.Middleware; /// - where T is /// - where TMessage1 is and TMessage2 is /// -public class ChatRequestMessageConnector : IMiddleware, IStreamingMiddleware +public class OpenAIChatRequestMessageConnector : IMiddleware, IStreamingMiddleware { private bool strictMode = false; - public ChatRequestMessageConnector(bool strictMode = false) + public OpenAIChatRequestMessageConnector(bool strictMode = false) { this.strictMode = strictMode; } - public string? Name => nameof(ChatRequestMessageConnector); + public string? Name => nameof(OpenAIChatRequestMessageConnector); public async Task InvokeAsync(MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default) { diff --git a/dotnet/src/AutoGen.SemanticKernel/Middleware/ChatMessageContentConnector.cs b/dotnet/src/AutoGen.SemanticKernel/Middleware/SemanticKernelChatMessageContentConnector.cs similarity index 97% rename from dotnet/src/AutoGen.SemanticKernel/Middleware/ChatMessageContentConnector.cs rename to dotnet/src/AutoGen.SemanticKernel/Middleware/SemanticKernelChatMessageContentConnector.cs index b21d8c1e1da..c812f32a5fc 100644 --- a/dotnet/src/AutoGen.SemanticKernel/Middleware/ChatMessageContentConnector.cs +++ b/dotnet/src/AutoGen.SemanticKernel/Middleware/SemanticKernelChatMessageContentConnector.cs @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. 
-// ChatMessageContentConnector.cs +// SemanticKernelChatMessageContentConnector.cs using System; using System.Collections.Generic; @@ -32,9 +32,9 @@ namespace AutoGen.SemanticKernel.Middleware; /// - /// - (streaming) /// -public class ChatMessageContentConnector : IMiddleware, IStreamingMiddleware +public class SemanticKernelChatMessageContentConnector : IMiddleware, IStreamingMiddleware { - public string? Name => nameof(ChatMessageContentConnector); + public string? Name => nameof(SemanticKernelChatMessageContentConnector); public async Task InvokeAsync(MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default) { diff --git a/dotnet/src/AutoGen.SemanticKernel/SemanticKernelAgent.cs b/dotnet/src/AutoGen.SemanticKernel/SemanticKernelAgent.cs index 78603a603d2..8fa6fc03c72 100644 --- a/dotnet/src/AutoGen.SemanticKernel/SemanticKernelAgent.cs +++ b/dotnet/src/AutoGen.SemanticKernel/SemanticKernelAgent.cs @@ -26,7 +26,7 @@ namespace AutoGen.SemanticKernel; /// (streaming) where T is /// /// -/// To support more AutoGen built-in , register with . +/// To support more AutoGen built-in , register with . 
/// public class SemanticKernelAgent : IStreamingAgent { diff --git a/dotnet/test/AutoGen.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.received.txt b/dotnet/test/AutoGen.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.received.txt new file mode 100644 index 00000000000..7b74fae336a --- /dev/null +++ b/dotnet/test/AutoGen.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.received.txt @@ -0,0 +1,219 @@ +[ + { + "OriginalMessage": "TextMessage(system, You are a helpful AI assistant, )", + "ConvertedMessages": [ + { + "Role": "system", + "Content": "You are a helpful AI assistant" + } + ] + }, + { + "OriginalMessage": "TextMessage(user, Hello, user)", + "ConvertedMessages": [ + { + "Role": "user", + "Content": "Hello", + "MultiModaItem": null + } + ] + }, + { + "OriginalMessage": "TextMessage(assistant, How can I help you?, assistant)", + "ConvertedMessages": [ + { + "Role": "assistant", + "Content": "How can I help you?", + "TooCall": [], + "FunctionCallName": null, + "FunctionCallArguments": null + } + ] + }, + { + "OriginalMessage": "Message(system, You are a helpful AI assistant, , , )", + "ConvertedMessages": [ + { + "Role": "system", + "Content": "You are a helpful AI assistant" + } + ] + }, + { + "OriginalMessage": "Message(user, Hello, user, , )", + "ConvertedMessages": [ + { + "Role": "user", + "Content": "Hello", + "MultiModaItem": null + } + ] + }, + { + "OriginalMessage": "Message(assistant, How can I help you?, assistant, , )", + "ConvertedMessages": [ + { + "Role": "assistant", + "Content": "How can I help you?", + "TooCall": [], + "FunctionCallName": null, + "FunctionCallArguments": null + } + ] + }, + { + "OriginalMessage": "Message(function, result, user, , )", + "ConvertedMessages": [ + { + "Role": "user", + "Content": "result", + "MultiModaItem": null + } + ] + }, + { + "OriginalMessage": "Message(assistant, , assistant, functionName, functionArguments)", + "ConvertedMessages": [ + { + "Role": "assistant", + "Content": null, 
+ "TooCall": [], + "FunctionCallName": "functionName", + "FunctionCallArguments": "functionArguments" + } + ] + }, + { + "OriginalMessage": "ImageMessage(user, https://example.com/image.png, user)", + "ConvertedMessages": [ + { + "Role": "user", + "Content": null, + "MultiModaItem": [ + { + "Type": "Image", + "ImageUrl": { + "Url": "https://example.com/image.png", + "Detail": null + } + } + ] + } + ] + }, + { + "OriginalMessage": "MultiModalMessage(assistant, user)\n\tTextMessage(user, Hello, user)\n\tImageMessage(user, https://example.com/image.png, user)", + "ConvertedMessages": [ + { + "Role": "user", + "Content": null, + "MultiModaItem": [ + { + "Type": "Text", + "Text": "Hello" + }, + { + "Type": "Image", + "ImageUrl": { + "Url": "https://example.com/image.png", + "Detail": null + } + } + ] + } + ] + }, + { + "OriginalMessage": "ToolCallMessage(assistant)\n\tToolCall(test, test, )", + "ConvertedMessages": [ + { + "Role": "assistant", + "Content": "", + "TooCall": [ + { + "Type": "Function", + "Name": "test", + "Arguments": "test", + "Id": "test" + } + ], + "FunctionCallName": null, + "FunctionCallArguments": null + } + ] + }, + { + "OriginalMessage": "ToolCallResultMessage(user)\n\tToolCall(test, test, result)", + "ConvertedMessages": [ + { + "Role": "tool", + "Content": "result", + "ToolCallId": "test" + } + ] + }, + { + "OriginalMessage": "ToolCallResultMessage(user)\n\tToolCall(result, test, test)\n\tToolCall(result, test, test)", + "ConvertedMessages": [ + { + "Role": "tool", + "Content": "test", + "ToolCallId": "result" + }, + { + "Role": "tool", + "Content": "test", + "ToolCallId": "result" + } + ] + }, + { + "OriginalMessage": "ToolCallMessage(assistant)\n\tToolCall(test, test, )\n\tToolCall(test, test, )", + "ConvertedMessages": [ + { + "Role": "assistant", + "Content": "", + "TooCall": [ + { + "Type": "Function", + "Name": "test", + "Arguments": "test", + "Id": "test" + }, + { + "Type": "Function", + "Name": "test", + "Arguments": "test", + "Id": 
"test" + } + ], + "FunctionCallName": null, + "FunctionCallArguments": null + } + ] + }, + { + "OriginalMessage": "AggregateMessage(assistant)\n\tToolCallMessage(assistant)\n\tToolCall(test, test, )\n\tToolCallResultMessage(assistant)\n\tToolCall(test, test, result)", + "ConvertedMessages": [ + { + "Role": "assistant", + "Content": "", + "TooCall": [ + { + "Type": "Function", + "Name": "test", + "Arguments": "test", + "Id": "test" + } + ], + "FunctionCallName": null, + "FunctionCallArguments": null + }, + { + "Role": "tool", + "Content": "result", + "ToolCallId": "test" + } + ] + } +] \ No newline at end of file diff --git a/dotnet/test/AutoGen.Tests/OpenAIChatAgentTest.cs b/dotnet/test/AutoGen.Tests/OpenAIChatAgentTest.cs new file mode 100644 index 00000000000..b83776245d1 --- /dev/null +++ b/dotnet/test/AutoGen.Tests/OpenAIChatAgentTest.cs @@ -0,0 +1,244 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// OpenAIChatAgentTest.cs + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using AutoGen.OpenAI; +using AutoGen.OpenAI.Middleware; +using Azure.AI.OpenAI; +using FluentAssertions; + +namespace AutoGen.Tests; + +public partial class OpenAIChatAgentTest +{ + /// + /// Get the weather for a location. + /// + /// location + /// + [Function] + public async Task GetWeatherAsync(string location) + { + return $"The weather in {location} is sunny."; + } + + [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT")] + public async Task BasicConversationTestAsync() + { + var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable."); + var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? 
throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable."); + var openaiClient = new OpenAIClient(new Uri(endpoint), new Azure.AzureKeyCredential(key)); + var openAIChatAgent = new OpenAIChatAgent( + openAIClient: openaiClient, + name: "assistant", + modelName: "gpt-35-turbo-16k"); + + // By default, OpenAIChatClient supports the following message types + // - IMessage + var chatMessageContent = MessageEnvelope.Create(new ChatRequestUserMessage("Hello")); + var reply = await openAIChatAgent.SendAsync(chatMessageContent); + + reply.Should().BeOfType>(); + reply.As>().From.Should().Be("assistant"); + reply.As>().Content.Role.Should().Be(ChatRole.Assistant); + + // test streaming + var streamingReply = await openAIChatAgent.GenerateStreamingReplyAsync(new[] { chatMessageContent }); + + await foreach (var streamingMessage in streamingReply) + { + streamingMessage.Should().BeOfType>(); + streamingMessage.As>().From.Should().Be("assistant"); + } + } + + [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT")] + public async Task OpenAIChatMessageContentConnectorTestAsync() + { + var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable."); + var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? 
throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable."); + var openaiClient = new OpenAIClient(new Uri(endpoint), new Azure.AzureKeyCredential(key)); + var openAIChatAgent = new OpenAIChatAgent( + openAIClient: openaiClient, + name: "assistant", + modelName: "gpt-35-turbo-16k"); + + var openAIChatMessageConnector = new OpenAIChatRequestMessageConnector(); + MiddlewareStreamingAgent assistant = openAIChatAgent + .RegisterStreamingMiddleware(openAIChatMessageConnector) + .RegisterMiddleware(openAIChatMessageConnector); + + var messages = new IMessage[] + { + MessageEnvelope.Create(new ChatRequestUserMessage("Hello")), + new TextMessage(Role.Assistant, "Hello", from: "user"), + new MultiModalMessage(Role.Assistant, + [ + new TextMessage(Role.Assistant, "Hello", from: "user"), + ], + from: "user"), + new Message(Role.Assistant, "Hello", from: "user"), // Message type is going to be deprecated, please use TextMessage instead + }; + + foreach (var message in messages) + { + var reply = await assistant.SendAsync(message); + + reply.Should().BeOfType(); + reply.As().From.Should().Be("assistant"); + } + + // test streaming + foreach (var message in messages) + { + var reply = await assistant.GenerateStreamingReplyAsync([message]); + + await foreach (var streamingMessage in reply) + { + streamingMessage.Should().BeOfType(); + streamingMessage.As().From.Should().Be("assistant"); + } + } + } + + [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT")] + public async Task OpenAIChatAgentToolCallTestAsync() + { + var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable."); + var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? 
throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable."); + var openaiClient = new OpenAIClient(new Uri(endpoint), new Azure.AzureKeyCredential(key)); + var openAIChatAgent = new OpenAIChatAgent( + openAIClient: openaiClient, + name: "assistant", + modelName: "gpt-35-turbo-16k"); + + var openAIChatMessageConnector = new OpenAIChatRequestMessageConnector(); + var functionCallMiddleware = new FunctionCallMiddleware( + functions: [this.GetWeatherAsyncFunctionContract]); + MiddlewareStreamingAgent assistant = openAIChatAgent + .RegisterStreamingMiddleware(openAIChatMessageConnector) + .RegisterMiddleware(openAIChatMessageConnector); + + var functionCallAgent = assistant + .RegisterMiddleware(functionCallMiddleware) + .RegisterStreamingMiddleware(functionCallMiddleware); + + var question = "What's the weather in Seattle"; + var messages = new IMessage[] + { + MessageEnvelope.Create(new ChatRequestUserMessage(question)), + new TextMessage(Role.Assistant, question, from: "user"), + new MultiModalMessage(Role.Assistant, + [ + new TextMessage(Role.Assistant, question, from: "user"), + ], + from: "user"), + new Message(Role.Assistant, question, from: "user"), // Message type is going to be deprecated, please use TextMessage instead + }; + + foreach (var message in messages) + { + var reply = await functionCallAgent.SendAsync(message); + + reply.Should().BeOfType(); + reply.As().From.Should().Be("assistant"); + reply.As().ToolCalls.Count().Should().Be(1); + reply.As().ToolCalls.First().FunctionName.Should().Be(this.GetWeatherAsyncFunctionContract.Name); + } + + // test streaming + foreach (var message in messages) + { + var reply = await functionCallAgent.GenerateStreamingReplyAsync([message]); + ToolCallMessage? 
toolCallMessage = null; + await foreach (var streamingMessage in reply) + { + streamingMessage.Should().BeOfType(); + streamingMessage.As().From.Should().Be("assistant"); + if (toolCallMessage is null) + { + toolCallMessage = new ToolCallMessage(streamingMessage.As()); + } + else + { + toolCallMessage.Update(streamingMessage.As()); + } + } + + toolCallMessage.Should().NotBeNull(); + toolCallMessage!.From.Should().Be("assistant"); + toolCallMessage.ToolCalls.Count().Should().Be(1); + toolCallMessage.ToolCalls.First().FunctionName.Should().Be(this.GetWeatherAsyncFunctionContract.Name); + } + } + + [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT")] + public async Task OpenAIChatAgentToolCallInvokingTestAsync() + { + var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable."); + var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable."); + var openaiClient = new OpenAIClient(new Uri(endpoint), new Azure.AzureKeyCredential(key)); + var openAIChatAgent = new OpenAIChatAgent( + openAIClient: openaiClient, + name: "assistant", + modelName: "gpt-35-turbo-16k"); + + var openAIChatMessageConnector = new OpenAIChatRequestMessageConnector(); + var functionCallMiddleware = new FunctionCallMiddleware( + functions: [this.GetWeatherAsyncFunctionContract], + functionMap: new Dictionary>> { { this.GetWeatherAsyncFunctionContract.Name!, this.GetWeatherAsyncWrapper } }); + MiddlewareStreamingAgent assistant = openAIChatAgent + .RegisterStreamingMiddleware(openAIChatMessageConnector) + .RegisterMiddleware(openAIChatMessageConnector); + + var functionCallAgent = assistant + .RegisterMiddleware(functionCallMiddleware) + .RegisterStreamingMiddleware(functionCallMiddleware); + + var question = "What's the weather in Seattle"; + var messages = new IMessage[] + { + MessageEnvelope.Create(new 
ChatRequestUserMessage(question)), + new TextMessage(Role.Assistant, question, from: "user"), + new MultiModalMessage(Role.Assistant, + [ + new TextMessage(Role.Assistant, question, from: "user"), + ], + from: "user"), + new Message(Role.Assistant, question, from: "user"), // Message type is going to be deprecated, please use TextMessage instead + }; + + foreach (var message in messages) + { + var reply = await functionCallAgent.SendAsync(message); + + reply.Should().BeOfType>(); + reply.From.Should().Be("assistant"); + reply.GetToolCalls()!.Count().Should().Be(1); + reply.GetToolCalls()!.First().FunctionName.Should().Be(this.GetWeatherAsyncFunctionContract.Name); + reply.GetContent()!.ToLower().Should().Contain("seattle"); + } + + // test streaming + foreach (var message in messages) + { + var reply = await functionCallAgent.GenerateStreamingReplyAsync([message]); + await foreach (var streamingMessage in reply) + { + if (streamingMessage is not IMessage) + { + streamingMessage.Should().BeOfType(); + streamingMessage.As().From.Should().Be("assistant"); + } + else + { + streamingMessage.Should().BeOfType>(); + streamingMessage.As().GetContent()!.ToLower().Should().Contain("seattle"); + } + } + } + } +} diff --git a/dotnet/test/AutoGen.Tests/OpenAIMessageTests.cs b/dotnet/test/AutoGen.Tests/OpenAIMessageTests.cs index 27d6eed42fc..7bc4701e357 100644 --- a/dotnet/test/AutoGen.Tests/OpenAIMessageTests.cs +++ b/dotnet/test/AutoGen.Tests/OpenAIMessageTests.cs @@ -63,7 +63,7 @@ public void BasicMessageTest() message1: new ToolCallMessage("test", "test", "assistant"), message2: new ToolCallResultMessage("result", "test", "test", "assistant"), "assistant"), ]; - var openaiMessageConnectorMiddleware = new ChatRequestMessageConnector(); + var openaiMessageConnectorMiddleware = new OpenAIChatRequestMessageConnector(); var agent = new EchoAgent("assistant"); var oaiMessages = messages.Select(m => (m, openaiMessageConnectorMiddleware.ProcessIncomingMessages(agent, [m]))); @@ 
-74,7 +74,7 @@ public void BasicMessageTest() public void ToOpenAIChatRequestMessageTest() { var agent = new EchoAgent("assistant"); - var middleware = new ChatRequestMessageConnector(); + var middleware = new OpenAIChatRequestMessageConnector(); // user message IMessage message = new TextMessage(Role.User, "Hello", "user"); @@ -254,7 +254,7 @@ public void ToOpenAIChatRequestMessageTest() public void ToOpenAIChatRequestMessageShortCircuitTest() { var agent = new EchoAgent("assistant"); - var middleware = new ChatRequestMessageConnector(); + var middleware = new OpenAIChatRequestMessageConnector(); ChatRequestMessage[] messages = [ new ChatRequestUserMessage("Hello"), diff --git a/dotnet/test/AutoGen.Tests/SemanticKernelAgentTest.cs b/dotnet/test/AutoGen.Tests/SemanticKernelAgentTest.cs index 6f7abd42e93..2f9dc80da3a 100644 --- a/dotnet/test/AutoGen.Tests/SemanticKernelAgentTest.cs +++ b/dotnet/test/AutoGen.Tests/SemanticKernelAgentTest.cs @@ -63,7 +63,7 @@ public async Task SemanticKernelChatMessageContentConnectorTestAsync() var kernel = builder.Build(); - var connector = new ChatMessageContentConnector(); + var connector = new SemanticKernelChatMessageContentConnector(); var skAgent = new SemanticKernelAgent(kernel, "assistant") .RegisterStreamingMiddleware(connector) .RegisterMiddleware(connector); @@ -119,7 +119,7 @@ public async Task SemanticKernelPluginTestAsync() builder.Plugins.AddFromFunctions("plugins", [function]); var kernel = builder.Build(); - var connector = new ChatMessageContentConnector(); + var connector = new SemanticKernelChatMessageContentConnector(); var skAgent = new SemanticKernelAgent(kernel, "assistant") .RegisterStreamingMiddleware(connector) .RegisterMiddleware(connector); From 222a57c2f5930481c063b20ae9cefc47327b6ccd Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Sun, 3 Mar 2024 14:40:17 -0800 Subject: [PATCH 27/27] update --- .../src/AutoGen.Core/Message/ImageMessage.cs | 9 + dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs | 4 +- 
...MessageTests.BasicMessageTest.received.txt | 219 ------------------ dotnet/test/AutoGen.Tests/SingleAgentTest.cs | 16 +- 4 files changed, 22 insertions(+), 226 deletions(-) delete mode 100644 dotnet/test/AutoGen.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.received.txt diff --git a/dotnet/src/AutoGen.Core/Message/ImageMessage.cs b/dotnet/src/AutoGen.Core/Message/ImageMessage.cs index 753a6d6e1e4..18ceea0d111 100644 --- a/dotnet/src/AutoGen.Core/Message/ImageMessage.cs +++ b/dotnet/src/AutoGen.Core/Message/ImageMessage.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // ImageMessage.cs +using System; + namespace AutoGen.Core; public class ImageMessage : IMessage @@ -12,6 +14,13 @@ public ImageMessage(Role role, string url, string? from = null) this.Url = url; } + public ImageMessage(Role role, Uri uri, string? from = null) + { + this.Role = role; + this.From = from; + this.Url = uri.ToString(); + } + public Role Role { get; set; } public string Url { get; set; } diff --git a/dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs b/dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs index 4e2c3a9c749..344987fd246 100644 --- a/dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs +++ b/dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs @@ -57,7 +57,7 @@ public class GPTAgent : IStreamingAgent _ => throw new ArgumentException($"Unsupported config type {config.GetType()}"), }; - _innerAgent = new OpenAIChatAgent(openAIClient, name, systemMessage, modelName, temperature, maxTokens, functions); + _innerAgent = new OpenAIChatAgent(openAIClient, name, modelName, systemMessage, temperature, maxTokens, functions); Name = name; this.functionMap = functionMap; } @@ -76,7 +76,7 @@ public class GPTAgent : IStreamingAgent this.modelName = modelName; Name = name; this.functionMap = functionMap; - _innerAgent = new OpenAIChatAgent(openAIClient, name, systemMessage, modelName, temperature, maxTokens, functions); + _innerAgent = new OpenAIChatAgent(openAIClient, name, 
modelName, systemMessage, temperature, maxTokens, functions); } public string Name { get; } diff --git a/dotnet/test/AutoGen.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.received.txt b/dotnet/test/AutoGen.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.received.txt deleted file mode 100644 index 7b74fae336a..00000000000 --- a/dotnet/test/AutoGen.Tests/ApprovalTests/OpenAIMessageTests.BasicMessageTest.received.txt +++ /dev/null @@ -1,219 +0,0 @@ -[ - { - "OriginalMessage": "TextMessage(system, You are a helpful AI assistant, )", - "ConvertedMessages": [ - { - "Role": "system", - "Content": "You are a helpful AI assistant" - } - ] - }, - { - "OriginalMessage": "TextMessage(user, Hello, user)", - "ConvertedMessages": [ - { - "Role": "user", - "Content": "Hello", - "MultiModaItem": null - } - ] - }, - { - "OriginalMessage": "TextMessage(assistant, How can I help you?, assistant)", - "ConvertedMessages": [ - { - "Role": "assistant", - "Content": "How can I help you?", - "TooCall": [], - "FunctionCallName": null, - "FunctionCallArguments": null - } - ] - }, - { - "OriginalMessage": "Message(system, You are a helpful AI assistant, , , )", - "ConvertedMessages": [ - { - "Role": "system", - "Content": "You are a helpful AI assistant" - } - ] - }, - { - "OriginalMessage": "Message(user, Hello, user, , )", - "ConvertedMessages": [ - { - "Role": "user", - "Content": "Hello", - "MultiModaItem": null - } - ] - }, - { - "OriginalMessage": "Message(assistant, How can I help you?, assistant, , )", - "ConvertedMessages": [ - { - "Role": "assistant", - "Content": "How can I help you?", - "TooCall": [], - "FunctionCallName": null, - "FunctionCallArguments": null - } - ] - }, - { - "OriginalMessage": "Message(function, result, user, , )", - "ConvertedMessages": [ - { - "Role": "user", - "Content": "result", - "MultiModaItem": null - } - ] - }, - { - "OriginalMessage": "Message(assistant, , assistant, functionName, functionArguments)", - "ConvertedMessages": [ - { - 
"Role": "assistant", - "Content": null, - "TooCall": [], - "FunctionCallName": "functionName", - "FunctionCallArguments": "functionArguments" - } - ] - }, - { - "OriginalMessage": "ImageMessage(user, https://example.com/image.png, user)", - "ConvertedMessages": [ - { - "Role": "user", - "Content": null, - "MultiModaItem": [ - { - "Type": "Image", - "ImageUrl": { - "Url": "https://example.com/image.png", - "Detail": null - } - } - ] - } - ] - }, - { - "OriginalMessage": "MultiModalMessage(assistant, user)\n\tTextMessage(user, Hello, user)\n\tImageMessage(user, https://example.com/image.png, user)", - "ConvertedMessages": [ - { - "Role": "user", - "Content": null, - "MultiModaItem": [ - { - "Type": "Text", - "Text": "Hello" - }, - { - "Type": "Image", - "ImageUrl": { - "Url": "https://example.com/image.png", - "Detail": null - } - } - ] - } - ] - }, - { - "OriginalMessage": "ToolCallMessage(assistant)\n\tToolCall(test, test, )", - "ConvertedMessages": [ - { - "Role": "assistant", - "Content": "", - "TooCall": [ - { - "Type": "Function", - "Name": "test", - "Arguments": "test", - "Id": "test" - } - ], - "FunctionCallName": null, - "FunctionCallArguments": null - } - ] - }, - { - "OriginalMessage": "ToolCallResultMessage(user)\n\tToolCall(test, test, result)", - "ConvertedMessages": [ - { - "Role": "tool", - "Content": "result", - "ToolCallId": "test" - } - ] - }, - { - "OriginalMessage": "ToolCallResultMessage(user)\n\tToolCall(result, test, test)\n\tToolCall(result, test, test)", - "ConvertedMessages": [ - { - "Role": "tool", - "Content": "test", - "ToolCallId": "result" - }, - { - "Role": "tool", - "Content": "test", - "ToolCallId": "result" - } - ] - }, - { - "OriginalMessage": "ToolCallMessage(assistant)\n\tToolCall(test, test, )\n\tToolCall(test, test, )", - "ConvertedMessages": [ - { - "Role": "assistant", - "Content": "", - "TooCall": [ - { - "Type": "Function", - "Name": "test", - "Arguments": "test", - "Id": "test" - }, - { - "Type": "Function", - "Name": 
"test", - "Arguments": "test", - "Id": "test" - } - ], - "FunctionCallName": null, - "FunctionCallArguments": null - } - ] - }, - { - "OriginalMessage": "AggregateMessage(assistant)\n\tToolCallMessage(assistant)\n\tToolCall(test, test, )\n\tToolCallResultMessage(assistant)\n\tToolCall(test, test, result)", - "ConvertedMessages": [ - { - "Role": "assistant", - "Content": "", - "TooCall": [ - { - "Type": "Function", - "Name": "test", - "Arguments": "test", - "Id": "test" - } - ], - "FunctionCallName": null, - "FunctionCallArguments": null - }, - { - "Role": "tool", - "Content": "result", - "ToolCallId": "test" - } - ] - } -] \ No newline at end of file diff --git a/dotnet/test/AutoGen.Tests/SingleAgentTest.cs b/dotnet/test/AutoGen.Tests/SingleAgentTest.cs index 5c8aae38c46..ab2fa18c089 100644 --- a/dotnet/test/AutoGen.Tests/SingleAgentTest.cs +++ b/dotnet/test/AutoGen.Tests/SingleAgentTest.cs @@ -67,25 +67,31 @@ public async Task GPTAgentVisionTestAsync() { nameof(GetHighestLabel), this.GetHighestLabelWrapper }, }); - + var imageUri = new Uri(@"https://raw.githubusercontent.com/microsoft/autogen/main/website/blog/2023-04-21-LLM-tuning-math/img/level2algebra.png"); var oaiMessage = new ChatRequestUserMessage( new ChatMessageTextContentItem("which label has the highest inference cost"), - new ChatMessageImageContentItem(new Uri(@"https://raw.githubusercontent.com/microsoft/autogen/main/website/blog/2023-04-21-LLM-tuning-math/img/level2algebra.png"))); + new ChatMessageImageContentItem(imageUri)); var multiModalMessage = new MultiModalMessage(Role.User, [ new TextMessage(Role.User, "which label has the highest inference cost", from: "user"), - new ImageMessage(Role.User, @"https://raw.githubusercontent.com/microsoft/autogen/main/website/blog/2023-04-21-LLM-tuning-math/img/level2algebra.png", from: "user"), + new ImageMessage(Role.User, imageUri, from: "user"), ], from: "user"); - foreach (var message in new IMessage[] { new MessageEnvelope(oaiMessage), multiModalMessage 
}) + var imageMessage = new ImageMessage(Role.User, imageUri, from: "user"); + + IMessage[] messages = [ + MessageEnvelope.Create(oaiMessage), + multiModalMessage, + imageMessage, + ]; + foreach (var message in messages) { var response = await visionAgent.SendAsync(message); response.From.Should().Be(visionAgent.Name); var labelResponse = await gpt3Agent.SendAsync(response); labelResponse.From.Should().Be(gpt3Agent.Name); - labelResponse.GetContent().Should().Be("[HIGHEST_LABEL] gpt-4 (n=5) green"); labelResponse.GetToolCalls()!.First().FunctionName.Should().Be(nameof(GetHighestLabel)); } }