Skip to content

Commit

Permalink
.Net Agents - Direct logger association with Agent (#6933)
Browse files Browse the repository at this point in the history
### Motivation and Context
<!-- Thank you for your contribution to the semantic-kernel repo!
Please help reviewers and future users, providing the following
information:
  1. Why is this change required?
  2. What problem does it solve?
  3. What scenario does it contribute to?
  4. If it fixes an open issue, please link to the issue here.
-->

Enable direct assignment of `LoggerFactory` for agent managed `Logger`
in support of single-agent (no chat) scenario.

### Description
<!-- Describe your changes, the overall approach, the underlying design.
These notes will help understanding how your code works. Thanks! -->
- Add `Agent.LoggerFactory` and `Agent.Logger`
- Remove `logger` parameter from `Agent.CreateChannelAsync`
- Remove `logger` parameter from `IChatHistoryHandler.InvokeAsync`
- Update all down-stream definitions


### Contribution Checklist
<!-- Before submitting this PR, please make sure: -->

- [X] The code builds clean without any errors or warnings
- [X] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [X] All unit tests pass, and I have added new tests where possible
- [X] I didn't break anyone 😄
  • Loading branch information
crickman committed Jun 25, 2024
1 parent 50fa21e commit 76e9db4
Show file tree
Hide file tree
Showing 15 changed files with 40 additions and 35 deletions.
2 changes: 2 additions & 0 deletions dotnet/samples/GettingStartedWithAgents/Step7_Logging.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ public async Task UseLoggerFactoryWithAgentGroupChatAsync()
Instructions = ReviewerInstructions,
Name = ReviewerName,
Kernel = this.CreateKernelWithChatCompletion(),
LoggerFactory = this.LoggerFactory,
};

ChatCompletionAgent agentWriter =
Expand All @@ -54,6 +55,7 @@ public async Task UseLoggerFactoryWithAgentGroupChatAsync()
Instructions = CopyWriterInstructions,
Name = CopyWriterName,
Kernel = this.CreateKernelWithChatCompletion(),
LoggerFactory = this.LoggerFactory,
};

// Create a chat for agent interaction.
Expand Down
16 changes: 14 additions & 2 deletions dotnet/src/Agents/Abstractions/Agent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;

namespace Microsoft.SemanticKernel.Agents;

Expand Down Expand Up @@ -36,6 +37,16 @@ public abstract class Agent
/// </summary>
public string? Name { get; init; }

/// <summary>
/// A <see cref="ILoggerFactory"/> for this <see cref="Agent"/>.
/// </summary>
public ILoggerFactory LoggerFactory { get; init; } = NullLoggerFactory.Instance;

/// <summary>
/// The <see cref="ILogger"/> associated with this <see cref="Agent"/>.
/// </summary>
protected ILogger Logger => this._logger ??= this.LoggerFactory.CreateLogger(this.GetType());

/// <summary>
/// Set of keys to establish channel affinity. Minimum expected key-set:
/// <example>
Expand All @@ -53,12 +64,13 @@ public abstract class Agent
/// <summary>
/// Produce the an <see cref="AgentChannel"/> appropriate for the agent type.
/// </summary>
/// <param name="logger">An agent specific logger.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>An <see cref="AgentChannel"/> appropriate for the agent type.</returns>
/// <remarks>
/// Every agent conversation, or <see cref="AgentChat"/>, will establish one or more <see cref="AgentChannel"/>
/// objects according to the specific <see cref="Agent"/> type.
/// </remarks>
protected internal abstract Task<AgentChannel> CreateChannelAsync(ILogger logger, CancellationToken cancellationToken);
protected internal abstract Task<AgentChannel> CreateChannelAsync(CancellationToken cancellationToken);

private ILogger? _logger;
}
5 changes: 1 addition & 4 deletions dotnet/src/Agents/Abstractions/AgentChat.cs
Original file line number Diff line number Diff line change
Expand Up @@ -256,10 +256,7 @@ async Task<AgentChannel> GetOrCreateChannelAsync()
{
this.Logger.LogDebug("[{MethodName}] Creating channel for {AgentType}: {AgentId}", nameof(InvokeAgentAsync), agent.GetType(), agent.Id);

// Creating an agent-typed logger for CreateChannelAsync
channel = await agent.CreateChannelAsync(this.LoggerFactory.CreateLogger(agent.GetType()), cancellationToken).ConfigureAwait(false);
// Creating an channel-typed logger for the channel
channel.Logger = this.LoggerFactory.CreateLogger(channel.GetType());
channel = await agent.CreateChannelAsync(cancellationToken).ConfigureAwait(false);

this._agentChannels.Add(channelKey, channel);

Expand Down
6 changes: 3 additions & 3 deletions dotnet/src/Agents/Abstractions/AggregatorAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,14 @@ protected internal override IEnumerable<string> GetChannelKeys()
}

/// <inheritdoc/>
protected internal override Task<AgentChannel> CreateChannelAsync(ILogger logger, CancellationToken cancellationToken)
protected internal override Task<AgentChannel> CreateChannelAsync(CancellationToken cancellationToken)
{
logger.LogDebug("[{MethodName}] Creating channel {ChannelType}", nameof(CreateChannelAsync), nameof(AggregatorChannel));
this.Logger.LogDebug("[{MethodName}] Creating channel {ChannelType}", nameof(CreateChannelAsync), nameof(AggregatorChannel));

AgentChat chat = chatProvider.Invoke();
AggregatorChannel channel = new(chat);

logger.LogInformation("[{MethodName}] Created channel {ChannelType} ({ChannelMode}) with: {AgentChatType}", nameof(CreateChannelAsync), nameof(AggregatorChannel), this.Mode, chat.GetType());
this.Logger.LogInformation("[{MethodName}] Created channel {ChannelType} ({ChannelMode}) with: {AgentChatType}", nameof(CreateChannelAsync), nameof(AggregatorChannel), this.Mode, chat.GetType());

return Task.FromResult<AgentChannel>(channel);
}
Expand Down
2 changes: 1 addition & 1 deletion dotnet/src/Agents/Abstractions/ChatHistoryChannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public class ChatHistoryChannel : AgentChannel
throw new KernelException($"Invalid channel binding for agent: {agent.Id} ({agent.GetType().FullName})");
}

await foreach (var message in historyHandler.InvokeAsync(this._history, this.Logger, cancellationToken).ConfigureAwait(false))
await foreach (var message in historyHandler.InvokeAsync(this._history, cancellationToken).ConfigureAwait(false))
{
this._history.Add(message);

Expand Down
11 changes: 8 additions & 3 deletions dotnet/src/Agents/Abstractions/ChatHistoryKernelAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,19 @@ protected internal sealed override IEnumerable<string> GetChannelKeys()
}

/// <inheritdoc/>
protected internal sealed override Task<AgentChannel> CreateChannelAsync(ILogger logger, CancellationToken cancellationToken)
protected internal sealed override Task<AgentChannel> CreateChannelAsync(CancellationToken cancellationToken)
{
return Task.FromResult<AgentChannel>(new ChatHistoryChannel());
ChatHistoryChannel channel =
new()
{
Logger = this.LoggerFactory.CreateLogger<ChatHistoryChannel>()
};

return Task.FromResult<AgentChannel>(channel);
}

/// <inheritdoc/>
public abstract IAsyncEnumerable<ChatMessageContent> InvokeAsync(
IReadOnlyList<ChatMessageContent> history,
ILogger logger,
CancellationToken cancellationToken = default);
}
3 changes: 0 additions & 3 deletions dotnet/src/Agents/Abstractions/IChatHistoryHandler.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Collections.Generic;
using System.Threading;
using Microsoft.Extensions.Logging;

namespace Microsoft.SemanticKernel.Agents;

Expand All @@ -14,11 +13,9 @@ public interface IChatHistoryHandler
/// Entry point for calling into an agent from a a <see cref="ChatHistoryChannel"/>.
/// </summary>
/// <param name="history">The chat history at the point the channel is created.</param>
/// <param name="logger">The logger associated with the <see cref="ChatHistoryChannel"/></param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>Asynchronous enumeration of messages.</returns>
IAsyncEnumerable<ChatMessageContent> InvokeAsync(
IReadOnlyList<ChatMessageContent> history,
ILogger logger,
CancellationToken cancellationToken = default);
}
7 changes: 3 additions & 4 deletions dotnet/src/Agents/Core/ChatCompletionAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ public sealed class ChatCompletionAgent : ChatHistoryKernelAgent
/// <inheritdoc/>
public override async IAsyncEnumerable<ChatMessageContent> InvokeAsync(
IReadOnlyList<ChatMessageContent> history,
ILogger logger,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var chatCompletionService = this.Kernel.GetRequiredService<IChatCompletionService>();
Expand All @@ -38,7 +37,7 @@ public sealed class ChatCompletionAgent : ChatHistoryKernelAgent

int messageCount = chat.Count;

logger.LogDebug("[{MethodName}] Invoking {ServiceType}.", nameof(InvokeAsync), chatCompletionService.GetType());
this.Logger.LogDebug("[{MethodName}] Invoking {ServiceType}.", nameof(InvokeAsync), chatCompletionService.GetType());

IReadOnlyList<ChatMessageContent> messages =
await chatCompletionService.GetChatMessageContentsAsync(
Expand All @@ -47,9 +46,9 @@ public sealed class ChatCompletionAgent : ChatHistoryKernelAgent
this.Kernel,
cancellationToken).ConfigureAwait(false);

if (logger.IsEnabled(LogLevel.Information)) // Avoid boxing if not enabled
if (this.Logger.IsEnabled(LogLevel.Information)) // Avoid boxing if not enabled
{
logger.LogInformation("[{MethodName}] Invoked {ServiceType} with message count: {MessageCount}.", nameof(InvokeAsync), chatCompletionService.GetType(), messages.Count);
this.Logger.LogInformation("[{MethodName}] Invoked {ServiceType} with message count: {MessageCount}.", nameof(InvokeAsync), chatCompletionService.GetType(), messages.Count);
}

// Capture mutated messages related function calling / tools
Expand Down
6 changes: 3 additions & 3 deletions dotnet/src/Agents/OpenAI/OpenAIAssistantAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -204,13 +204,13 @@ protected override IEnumerable<string> GetChannelKeys()
}

/// <inheritdoc/>
protected override async Task<AgentChannel> CreateChannelAsync(ILogger logger, CancellationToken cancellationToken)
protected override async Task<AgentChannel> CreateChannelAsync(CancellationToken cancellationToken)
{
logger.LogDebug("[{MethodName}] Creating assistant thread", nameof(CreateChannelAsync));
this.Logger.LogDebug("[{MethodName}] Creating assistant thread", nameof(CreateChannelAsync));

AssistantThread thread = await this._client.CreateThreadAsync(cancellationToken).ConfigureAwait(false);

logger.LogInformation("[{MethodName}] Created assistant thread: {ThreadId}", nameof(CreateChannelAsync), thread.Id);
this.Logger.LogInformation("[{MethodName}] Created assistant thread: {ThreadId}", nameof(CreateChannelAsync), thread.Id);

return new OpenAIAssistantChannel(this._client, thread.Id, this._config.Polling);
}
Expand Down
3 changes: 1 addition & 2 deletions dotnet/src/Agents/UnitTests/AgentChannelTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Agents;
using Xunit;
Expand Down Expand Up @@ -68,7 +67,7 @@ private sealed class NextAgent : TestAgent;

private class TestAgent : KernelAgent
{
protected internal override Task<AgentChannel> CreateChannelAsync(ILogger logger, CancellationToken cancellationToken)
protected internal override Task<AgentChannel> CreateChannelAsync(CancellationToken cancellationToken)
{
throw new NotImplementedException();
}
Expand Down
2 changes: 0 additions & 2 deletions dotnet/src/Agents/UnitTests/AgentChatTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Agents;
using Microsoft.SemanticKernel.ChatCompletion;
Expand Down Expand Up @@ -137,7 +136,6 @@ private sealed class TestAgent : ChatHistoryKernelAgent

public override async IAsyncEnumerable<ChatMessageContent> InvokeAsync(
IReadOnlyList<ChatMessageContent> history,
ILogger logger,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
await Task.Delay(0, cancellationToken);
Expand Down
3 changes: 1 addition & 2 deletions dotnet/src/Agents/UnitTests/AggregatorAgentTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Agents;
using Microsoft.SemanticKernel.ChatCompletion;
Expand Down Expand Up @@ -88,7 +87,7 @@ private static Mock<ChatHistoryKernelAgent> CreateMockAgent()
Mock<ChatHistoryKernelAgent> agent = new();

ChatMessageContent[] messages = [new ChatMessageContent(AuthorRole.Assistant, "test agent")];
agent.Setup(a => a.InvokeAsync(It.IsAny<IReadOnlyList<ChatMessageContent>>(), It.IsAny<ILogger>(), It.IsAny<CancellationToken>())).Returns(() => messages.ToAsyncEnumerable());
agent.Setup(a => a.InvokeAsync(It.IsAny<IReadOnlyList<ChatMessageContent>>(), It.IsAny<CancellationToken>())).Returns(() => messages.ToAsyncEnumerable());

return agent;
}
Expand Down
3 changes: 1 addition & 2 deletions dotnet/src/Agents/UnitTests/ChatHistoryChannelTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Agents;
using Xunit;
Expand All @@ -30,7 +29,7 @@ public async Task VerifyAgentWithoutIChatHistoryHandlerAsync()

private sealed class TestAgent : KernelAgent
{
protected internal override Task<AgentChannel> CreateChannelAsync(ILogger logger, CancellationToken cancellationToken)
protected internal override Task<AgentChannel> CreateChannelAsync(CancellationToken cancellationToken)
{
throw new NotImplementedException();
}
Expand Down
3 changes: 1 addition & 2 deletions dotnet/src/Agents/UnitTests/Core/AgentGroupChatTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Agents;
using Microsoft.SemanticKernel.Agents.Chat;
Expand Down Expand Up @@ -199,7 +198,7 @@ private static Mock<ChatHistoryKernelAgent> CreateMockAgent()
Mock<ChatHistoryKernelAgent> agent = new();

ChatMessageContent[] messages = [new ChatMessageContent(AuthorRole.Assistant, "test")];
agent.Setup(a => a.InvokeAsync(It.IsAny<IReadOnlyList<ChatMessageContent>>(), It.IsAny<ILogger>(), It.IsAny<CancellationToken>())).Returns(() => messages.ToAsyncEnumerable());
agent.Setup(a => a.InvokeAsync(It.IsAny<IReadOnlyList<ChatMessageContent>>(), It.IsAny<CancellationToken>())).Returns(() => messages.ToAsyncEnumerable());

return agent;
}
Expand Down
3 changes: 1 addition & 2 deletions dotnet/src/Agents/UnitTests/Core/ChatCompletionAgentTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Agents;
using Microsoft.SemanticKernel.ChatCompletion;
Expand Down Expand Up @@ -60,7 +59,7 @@ public async Task VerifyChatCompletionAgentInvocationAsync()
ExecutionSettings = new(),
};

var result = await agent.InvokeAsync([], NullLogger.Instance).ToArrayAsync();
var result = await agent.InvokeAsync([]).ToArrayAsync();

Assert.Single(result);

Expand Down

0 comments on commit 76e9db4

Please sign in to comment.