Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Copilot chat: refactor chat history #682

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
@@ -0,0 +1,194 @@
// Copyright (c) Microsoft. All rights reserved.

using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc;
using SemanticKernel.Service.Model;
using SemanticKernel.Service.Skills;
using SemanticKernel.Service.Storage;

namespace SemanticKernel.Service.Controllers;

/// <summary>
/// Controller for chat history.
TaoChenOSU marked this conversation as resolved.
Show resolved Hide resolved
/// This controller is responsible for creating new chat sessions, retrieving chat sessions,
/// retrieving chat messages, and editing chat sessions.
/// </summary>
[ApiController]
[Authorize]
public class ChatHistoryController : ControllerBase
{
private readonly ILogger<ChatHistoryController> _logger;
private readonly ChatSessionRepository _chatSessionRepository;
private readonly ChatMessageRepository _chatMessageRepository;
private readonly PromptSettings _promptSettings;

/// <summary>
/// Initializes a new instance of the <see cref="ChatHistoryController"/> class.
/// </summary>
/// <param name="logger">The logger.</param>
/// <param name="chatSessionRepository">The chat session repository.</param>
/// <param name="chatMessageRepository">The chat message repository.</param>
/// <param name="promptSettings">The prompt settings.</param>
public ChatHistoryController(
ILogger<ChatHistoryController> logger,
ChatSessionRepository chatSessionRepository,
ChatMessageRepository chatMessageRepository,
PromptSettings promptSettings)
{
this._logger = logger;
this._chatSessionRepository = chatSessionRepository;
this._chatMessageRepository = chatMessageRepository;
this._promptSettings = promptSettings;
}

/// <summary>
/// Create a new chat session and populate the session with the initial bot message.
/// </summary>
/// <param name="chatParameters">Object that contains the parameters to create a new chat.</param>
/// <returns>The HTTP action result.</returns>
[HttpPost]
[Route("chatSession/create")]
[ProducesResponseType(StatusCodes.Status201Created)]
[ProducesResponseType(StatusCodes.Status400BadRequest)]
[ProducesResponseType(StatusCodes.Status404NotFound)]
public async Task<IActionResult> CreateChatSessionAsync(
[FromBody] ChatSession chatParameters)
{
var userId = chatParameters.UserId;
var title = chatParameters.Title;

var newChat = new ChatSession(userId, title);
await this._chatSessionRepository.CreateAsync(newChat);

var initialBotMessage = this._promptSettings.InitialBotMessage;
await this.SaveResponseAsync(initialBotMessage, newChat.Id);

this._logger.LogDebug("Created chat session with id {0} for user {1}.", newChat.Id, userId);
return this.CreatedAtAction(nameof(this.GetChatSessionByIdAsync), new { chatId = newChat.Id }, newChat);
}

/// <summary>
/// Get a chat session by id.
/// </summary>
/// <param name="chatId">The chat id.</param>
[HttpGet]
[ActionName("GetChatSessionByIdAsync")]
[Route("chatSession/getChat/{chatId:guid}")]
[ProducesResponseType(StatusCodes.Status200OK)]
[ProducesResponseType(StatusCodes.Status400BadRequest)]
[ProducesResponseType(StatusCodes.Status404NotFound)]
public async Task<IActionResult> GetChatSessionByIdAsync(Guid chatId)
{
var chat = await this._chatSessionRepository.FindByIdAsync(chatId.ToString());
if (chat == null)
{
return this.NotFound($"Chat of id {chatId} not found.");
}

return this.Ok(chat);
}

/// <summary>
/// Get all chat sessions associated with a user. Return an empty list if no chats are found.
/// The regex pattern that is used to match the user id will match the following format:
/// - 2 period separated groups of one or more hyphen-delimitated alphanumeric strings.
/// The pattern matches two GUIDs in canonical textual representation separated by a period.
/// </summary>
/// <param name="userId">The user id.</param>
[HttpGet]
[Route("chatSession/getAllChats/{userId:regex(([[a-z0-9]]+-)+[[a-z0-9]]+\\.([[a-z0-9]]+-)+[[a-z0-9]]+)}")]
[ProducesResponseType(StatusCodes.Status200OK)]
[ProducesResponseType(StatusCodes.Status400BadRequest)]
[ProducesResponseType(StatusCodes.Status404NotFound)]
public async Task<IActionResult> GetAllChatSessionsAsync(string userId)
{
var chats = await this._chatSessionRepository.FindByUserIdAsync(userId);
if (chats == null)
{
// Return an empty list if no chats are found
return this.Ok(new List<ChatSession>());
}

return this.Ok(chats);
}

/// <summary>
/// Get all chat messages for a chat session.
/// The list will be ordered with the first entry being the most recent message.
/// </summary>
/// <param name="chatId">The chat id.</param>
/// <param name="startIdx">The start index at which the first message will be returned.</param>
/// <param name="count">The number of messages to return. -1 will return all messages starting from startIdx.</param>
/// [Authorize]
[HttpGet]
[Route("chatSession/getChatMessages/{chatId:guid}")]
[ProducesResponseType(StatusCodes.Status200OK)]
[ProducesResponseType(StatusCodes.Status400BadRequest)]
[ProducesResponseType(StatusCodes.Status404NotFound)]
public async Task<IActionResult> GetChatMessagesAsync(
Guid chatId,
[FromQuery] int startIdx = 0,
[FromQuery] int count = -1)
{
var chatMessages = await this._chatMessageRepository.FindByChatIdAsync(chatId.ToString());
if (chatMessages == null)
{
return this.NotFound($"No messages found for chat of id {chatId}.");
}

if (startIdx >= chatMessages.Count())
{
return this.BadRequest($"Start index {startIdx} is out of range.");
}
else if (startIdx + count > chatMessages.Count() || count == -1)
{
count = chatMessages.Count() - startIdx;
}

chatMessages = chatMessages.OrderByDescending(m => m.Timestamp).Skip(startIdx).Take(count);
return this.Ok(chatMessages);
}

/// <summary>
/// Edit a chat session.
/// </summary>
/// <param name="chatParameters">Object that contains the parameters to edit the chat.</param>
[HttpPost]
[Route("chatSession/edit")]
[ProducesResponseType(StatusCodes.Status200OK)]
[ProducesResponseType(StatusCodes.Status400BadRequest)]
[ProducesResponseType(StatusCodes.Status404NotFound)]
public async Task<IActionResult> EditChatSessionAsync([FromBody] ChatSession chatParameters)
{
var chatId = chatParameters.Id;

var chat = await this._chatSessionRepository.FindByIdAsync(chatId.ToString());
if (chat == null)
{
return this.NotFound($"Chat of id {chatId} not found.");
}

chat.Title = chatParameters.Title;
await this._chatSessionRepository.UpdateAsync(chat);

return this.Ok(chat);
}

# region Private

/// <summary>
/// Save a bot response to the chat session.
/// </summary>
/// <param name="response">The bot response.</param>
/// <param name="chatId">The chat id.</param>
private async Task SaveResponseAsync(string response, string chatId)
{
// Make sure the chat session exists
await this._chatSessionRepository.FindByIdAsync(chatId);

var chatMessage = ChatMessage.CreateBotResponseMessage(chatId, response);
await this._chatMessageRepository.CreateAsync(chatMessage);
}

# endregion
}
Expand Up @@ -155,8 +155,6 @@ private string ReadPdfFile(IFormFile file)
fileContent += text;
}

Console.WriteLine(fileContent);

return fileContent;
}

Expand Down
Expand Up @@ -63,13 +63,6 @@ internal static class FunctionLoadingExtensions
);
kernel.ImportSkill(chatSkill, nameof(ChatSkill));

var chatHistorySkill = new ChatHistorySkill(
chatMessageRepository,
chatSessionRepository,
promptSettings
);
kernel.ImportSkill(chatHistorySkill, nameof(ChatHistorySkill));

var documentMemorySkill = new DocumentMemorySkill(promptSettings, documentMemoryOptions);
kernel.ImportSkill(documentMemorySkill, nameof(DocumentMemorySkill));
}
Expand Down
Expand Up @@ -5,7 +5,7 @@
using System.Text.Json.Serialization;
using SemanticKernel.Service.Storage;

namespace SemanticKernel.Service.Skills;
namespace SemanticKernel.Service.Model;

/// <summary>
/// Information about a single chat message.
Expand Down
Expand Up @@ -3,7 +3,7 @@
using System.Text.Json.Serialization;
using SemanticKernel.Service.Storage;

namespace SemanticKernel.Service.Skills;
namespace SemanticKernel.Service.Model;

/// <summary>
/// A chat session
Expand Down
2 changes: 1 addition & 1 deletion samples/apps/copilot-chat-app/webapi/ServiceExtensions.cs
Expand Up @@ -6,7 +6,7 @@
using Microsoft.Identity.Web;
using SemanticKernel.Service.Auth;
using SemanticKernel.Service.Config;
using SemanticKernel.Service.Skills;
using SemanticKernel.Service.Model;
using SemanticKernel.Service.Storage;

namespace SemanticKernel.Service;
Expand Down