Skip to content

Commit

Permalink
Merge pull request #18 from cnblogs/support-function-behavior
Browse files Browse the repository at this point in the history
feat: support auto invoke functions for non-stream chat
  • Loading branch information
ikesnowy committed Mar 18, 2024
2 parents b2d1e95 + 041318f commit a191cfa
Show file tree
Hide file tree
Showing 26 changed files with 925 additions and 60 deletions.
11 changes: 8 additions & 3 deletions src/KernelMemory.DashScope/DashScopeTextEmbeddingGenerator.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using Cnblogs.DashScope.Sdk;
using Cnblogs.DashScope.Sdk.TextEmbedding;
using Cnblogs.DashScope.Core;
using Microsoft.KernelMemory;
using Microsoft.KernelMemory.AI;

Expand Down Expand Up @@ -30,7 +29,13 @@ public int CountTokens(string text)
string text,
CancellationToken cancellationToken = new())
{
var result = await dashScopeClient.GetTextEmbeddingsAsync(modelId, [text], null, cancellationToken);
var result = await dashScopeClient.GetEmbeddingsAsync(
new ModelRequest<TextEmbeddingInput, ITextEmbeddingParameters>
{
Input = new TextEmbeddingInput { Texts = [text] },
Model = modelId
},
cancellationToken);
return result.Output.Embeddings[0].Embedding;
}

Expand Down
2 changes: 1 addition & 1 deletion src/KernelMemory.DashScope/DashScopeTextGenerator.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using System.Runtime.CompilerServices;
using Cnblogs.DashScope.Sdk;
using Cnblogs.DashScope.Core;
using Microsoft.Extensions.Logging;
using Microsoft.KernelMemory.AI;
using Microsoft.KernelMemory.Diagnostics;
Expand Down
2 changes: 1 addition & 1 deletion src/KernelMemory.DashScope/DependencyInjector.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Cnblogs.DashScope.Sdk;
using Cnblogs.DashScope.Core;
using Cnblogs.KernelMemory.AI.DashScope;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
Expand Down
4 changes: 2 additions & 2 deletions src/KernelMemory.DashScope/KernelMemory.DashScope.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@

<ItemGroup>
<PackageReference Include="Microsoft.DeepDev.TokenizerLib" Version="1.3.3" />
<PackageReference Include="Microsoft.KernelMemory.Abstractions" Version="0.32.240307.1"/>
<PackageReference Include="Cnblogs.DashScope.Sdk" Version="0.0.3"/>
<PackageReference Include="Microsoft.KernelMemory.Abstractions" Version="0.34.240313.1" />
<PackageReference Include="Cnblogs.DashScope.Core" Version="0.2.0" />
</ItemGroup>

<ItemGroup>
Expand Down
208 changes: 191 additions & 17 deletions src/SemanticKernel.DashScope/DashScopeChatCompletionService.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using System.Runtime.CompilerServices;
using Cnblogs.DashScope.Sdk;
using System.Text.Json;
using Cnblogs.DashScope.Core;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Services;
Expand All @@ -15,45 +17,132 @@ public sealed class DashScopeChatCompletionService : IChatCompletionService, ITe
private readonly IDashScopeClient _dashScopeClient;
private readonly Dictionary<string, object?> _attributes = new();
private readonly string _modelId;
private readonly ILogger<DashScopeChatCompletionService> _logger;

/// <summary>
/// Creates a new DashScope chat completion service.
/// </summary>
/// <param name="modelId"></param>
/// <param name="dashScopeClient"></param>
public DashScopeChatCompletionService(string modelId, IDashScopeClient dashScopeClient)
/// <param name="logger"></param>
public DashScopeChatCompletionService(
string modelId,
IDashScopeClient dashScopeClient,
ILogger<DashScopeChatCompletionService> logger)
{
_dashScopeClient = dashScopeClient;
_modelId = modelId;
_logger = logger;
_attributes.Add(AIServiceExtensions.ModelIdKey, _modelId);
}

/// <inheritdoc />
public async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsync(
ChatHistory chatHistory,
ChatHistory chat,
PromptExecutionSettings? executionSettings = null,
Kernel? kernel = null,
CancellationToken cancellationToken = default)
{
var chatMessages = chatHistory.ToChatMessages();
var chatParameters = DashScopePromptExecutionSettings.FromPromptExecutionSettings(executionSettings);
chatParameters ??= new DashScopePromptExecutionSettings();
chatParameters.IncrementalOutput = false;
chatParameters.ResultFormat = ResultFormats.Message;
var response = await _dashScopeClient.GetTextCompletionAsync(
new ModelRequest<TextGenerationInput, ITextGenerationParameters>
chatParameters.ToolCallBehavior?.ConfigureOptions(kernel, chatParameters);

var autoInvoke = kernel is not null && chatParameters.ToolCallBehavior?.MaximumAutoInvokeAttempts > 0;
for (var it = 1;; it++)
{
var response = await _dashScopeClient.GetTextCompletionAsync(
new ModelRequest<TextGenerationInput, ITextGenerationParameters>
{
Input = new TextGenerationInput { Messages = chat.ToChatMessages() },
Model = string.IsNullOrEmpty(chatParameters.ModelId) ? _modelId : chatParameters.ModelId,
Parameters = chatParameters
},
cancellationToken);
CaptureTokenUsage(response.Usage);
EnsureChoiceExists(response.Output.Choices);
var message = response.Output.Choices![0].Message;
var chatMessageContent = new DashScopeChatMessageContent(
new AuthorRole(message.Role),
message.Content,
name: null,
toolCalls: message.ToolCalls,
metadata: response.ToMetaData());
if (autoInvoke == false || message.ToolCalls is null)
{
Input = new TextGenerationInput { Messages = chatMessages },
Model = string.IsNullOrEmpty(chatParameters.ModelId) ? _modelId : chatParameters.ModelId,
Parameters = chatParameters
},
cancellationToken);
var message = response.Output.Choices![0].Message;
var chatMessageContent = new ChatMessageContent(
new AuthorRole(message.Role),
message.Content,
metadata: response.ToMetaData());
return [chatMessageContent];
// no needs to invoke tool
return [chatMessageContent];
}

LogToolCalls(message.ToolCalls);
chat.Add(chatMessageContent);

foreach (var call in message.ToolCalls)
{
if (call.Type is not ToolTypes.Function || call.Function is null)
{
AddResponseMessage(chat, null, "Error: Tool call was not a function call.", call.Id);
continue;
}

// ensure not calling function that was not included in request list.
if (chatParameters.Tools?.Any(
x => string.Equals(x.Function?.Name, call.Function.Name, StringComparison.OrdinalIgnoreCase))
!= true)
{
AddResponseMessage(
chat,
null,
"Error: Function call requests for a function that wasn't defined.",
call.Id);
continue;
}

object? callResult;
try
{
if (kernel!.Plugins.TryGetKernelFunctionAndArguments(
call.Function,
out var kernelFunction,
out var kernelArguments)
== false)
{
AddResponseMessage(chat, null, "Error: Requested function could not be found.", call.Id);
continue;
}

var functionResult = await kernelFunction.InvokeAsync(kernel, kernelArguments, cancellationToken);
callResult = functionResult.GetValue<object>() ?? string.Empty;
}
catch (JsonException)
{
AddResponseMessage(chat, null, "Error: Function call arguments were invalid JSON.", call.Id);
continue;
}
catch (Exception)
{
AddResponseMessage(chat, null, "Error: Exception while invoking function. {e.Message}", call.Id);
continue;
}

var stringResult = ProcessFunctionResult(callResult, chatParameters.ToolCallBehavior);
AddResponseMessage(chat, stringResult, null, call.Id);
}

chatParameters.Tools?.Clear();
chatParameters.ToolCallBehavior?.ConfigureOptions(kernel, chatParameters);
if (it >= chatParameters.ToolCallBehavior!.MaximumAutoInvokeAttempts)
{
autoInvoke = false;
if (_logger.IsEnabled(LogLevel.Debug))
{
_logger.LogDebug(
"Maximum auto-invoke ({MaximumAutoInvoke}) reached",
chatParameters.ToolCallBehavior!.MaximumAutoInvokeAttempts);
}
}
}
}

/// <inheritdoc />
Expand All @@ -68,6 +157,7 @@ public DashScopeChatCompletionService(string modelId, IDashScopeClient dashScope
var parameters = DashScopePromptExecutionSettings.FromPromptExecutionSettings(executionSettings);
parameters.IncrementalOutput = true;
parameters.ResultFormat = ResultFormats.Message;
parameters.ToolCallBehavior?.ConfigureOptions(kernel, parameters);
var responses = _dashScopeClient.GetTextCompletionStreamAsync(
new ModelRequest<TextGenerationInput, ITextGenerationParameters>
{
Expand Down Expand Up @@ -141,4 +231,88 @@ await foreach (var response in responses)
metadata: response.ToMetaData());
}
}

private void CaptureTokenUsage(TextGenerationTokenUsage? usage)
{
if (usage is null)
{
if (_logger.IsEnabled(LogLevel.Debug))
{
_logger.LogDebug("Usage info is not available");
}

return;
}

if (_logger.IsEnabled(LogLevel.Information))
{
_logger.LogInformation(
"Input tokens: {InputTokens}. Output tokens: {CompletionTokens}. Total tokens: {TotalTokens}",
usage.InputTokens,
usage.OutputTokens,
usage.TotalTokens);
}
}

private void LogToolCalls(IReadOnlyCollection<ToolCall>? calls)
{
if (calls is null)
{
return;
}

if (_logger.IsEnabled(LogLevel.Debug))
{
_logger.LogDebug("Tool requests: {Requests}", calls.Count);
}

if (_logger.IsEnabled(LogLevel.Trace))
{
_logger.LogTrace(
"Function call requests: {Requests}",
string.Join(", ", calls.Select(ftc => $"{ftc.Function?.Name}({ftc.Function?.Arguments})")));
}
}

private void AddResponseMessage(ChatHistory chat, string? result, string? errorMessage, string? toolId)
{
// Log any error
if (errorMessage is not null && _logger.IsEnabled(LogLevel.Debug))
{
_logger.LogDebug("Failed to handle tool request ({ToolId}). {Error}", toolId, errorMessage);
}

// Add the tool response message to both the chat options and to the chat history.
result ??= errorMessage ?? string.Empty;
chat.Add(new DashScopeChatMessageContent(AuthorRole.Tool, result, name: toolId));
}

private static void EnsureChoiceExists(List<TextGenerationChoice>? choices)
{
if (choices is null || choices.Count == 0)
{
throw new KernelException("No choice was returned from model");
}
}

private static string ProcessFunctionResult(object functionResult, ToolCallBehavior? toolCallBehavior)
{
if (functionResult is string stringResult)
{
return stringResult;
}

// This is an optimization to use ChatMessageContent content directly
// without unnecessary serialization of the whole message content class.
if (functionResult is ChatMessageContent chatMessageContent)
{
return chatMessageContent.ToString();
}

// For polymorphic serialization of unknown in advance child classes of the KernelContent class,
// a corresponding JsonTypeInfoResolver should be provided via the JsonSerializerOptions.TypeInfoResolver property.
// For more details about the polymorphic serialization, see the article at:
// https://learn.microsoft.com/en-us/dotnet/standard/serialization/system-text-json/polymorphism?pivots=dotnet-8-0
return JsonSerializer.Serialize(functionResult, toolCallBehavior?.ToolCallResultSerializerOptions);
}
}
27 changes: 27 additions & 0 deletions src/SemanticKernel.DashScope/DashScopeChatMessageContent.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
using Cnblogs.DashScope.Core;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;

namespace Cnblogs.SemanticKernel.Connectors.DashScope;

/// <summary>
/// DashScope specialized message content
/// </summary>
public class DashScopeChatMessageContent(
AuthorRole role,
string content,
Dictionary<string, object?>? metadata = null,
string? name = null,
List<ToolCall>? toolCalls = null)
: ChatMessageContent(role, content, metadata: metadata)
{
/// <summary>
/// The name of tool if role is tool.
/// </summary>
public string? Name { get; } = name;

/// <summary>
/// Optional tool calls.
/// </summary>
public List<ToolCall>? ToolCalls { get; } = toolCalls;
}
13 changes: 11 additions & 2 deletions src/SemanticKernel.DashScope/DashScopeMapper.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Cnblogs.DashScope.Sdk;
using Cnblogs.DashScope.Core;
using Microsoft.SemanticKernel.ChatCompletion;

namespace Cnblogs.SemanticKernel.Connectors.DashScope;
Expand All @@ -7,7 +7,16 @@ internal static class DashScopeMapper
{
public static List<ChatMessage> ToChatMessages(this ChatHistory history)
{
return history.Select(x => new ChatMessage(x.Role.Label, x.Content ?? string.Empty)).ToList();
return history.Select(
x =>
{
if (x is DashScopeChatMessageContent d)
{
return new ChatMessage(x.Role.Label, x.Content ?? string.Empty, d.Name, ToolCalls: d.ToolCalls);
}
return new ChatMessage(x.Role.Label, x.Content ?? string.Empty);
}).ToList();
}

public static Dictionary<string, object?>? ToMetaData<TOutput, TUsage>(
Expand Down
Loading

0 comments on commit a191cfa

Please sign in to comment.