Skip to content

Commit

Permalink
Copilot chat: refactor chat history (#682)
Browse files Browse the repository at this point in the history
### Motivation and Context
The chat history skill handles the creation of new chat sessions,
retrieval of chat sessions and messages, and editing of chat titles. It
was unnecessary and improper to use a skill to handle non-AI related
task.

### Description
Replace the chat history skill with a controller that provides APIs to
handle the above-mentioned tasks.
  • Loading branch information
TaoChenOSU committed Apr 28, 2023
1 parent b6db9a2 commit 3362c20
Show file tree
Hide file tree
Showing 24 changed files with 385 additions and 326 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
// Copyright (c) Microsoft. All rights reserved.

using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc;
using SemanticKernel.Service.Model;
using SemanticKernel.Service.Skills;
using SemanticKernel.Service.Storage;

namespace SemanticKernel.Service.Controllers;

/// <summary>
/// Controller for chat history.
/// This controller is responsible for creating new chat sessions, retrieving chat sessions,
/// retrieving chat messages, and editing chat sessions.
/// </summary>
[ApiController]
[Authorize]
public class ChatHistoryController : ControllerBase
{
private readonly ILogger<ChatHistoryController> _logger;
private readonly ChatSessionRepository _chatSessionRepository;
private readonly ChatMessageRepository _chatMessageRepository;
private readonly PromptSettings _promptSettings;

/// <summary>
/// Initializes a new instance of the <see cref="ChatHistoryController"/> class.
/// </summary>
/// <param name="logger">The logger.</param>
/// <param name="chatSessionRepository">The chat session repository.</param>
/// <param name="chatMessageRepository">The chat message repository.</param>
/// <param name="promptSettings">The prompt settings.</param>
public ChatHistoryController(
ILogger<ChatHistoryController> logger,
ChatSessionRepository chatSessionRepository,
ChatMessageRepository chatMessageRepository,
PromptSettings promptSettings)
{
this._logger = logger;
this._chatSessionRepository = chatSessionRepository;
this._chatMessageRepository = chatMessageRepository;
this._promptSettings = promptSettings;
}

/// <summary>
/// Create a new chat session and populate the session with the initial bot message.
/// </summary>
/// <param name="chatParameters">Object that contains the parameters to create a new chat.</param>
/// <returns>The HTTP action result.</returns>
[HttpPost]
[Route("chatSession/create")]
[ProducesResponseType(StatusCodes.Status201Created)]
[ProducesResponseType(StatusCodes.Status400BadRequest)]
[ProducesResponseType(StatusCodes.Status404NotFound)]
public async Task<IActionResult> CreateChatSessionAsync(
[FromBody] ChatSession chatParameters)
{
var userId = chatParameters.UserId;
var title = chatParameters.Title;

var newChat = new ChatSession(userId, title);
await this._chatSessionRepository.CreateAsync(newChat);

var initialBotMessage = this._promptSettings.InitialBotMessage;
await this.SaveResponseAsync(initialBotMessage, newChat.Id);

this._logger.LogDebug("Created chat session with id {0} for user {1}.", newChat.Id, userId);
return this.CreatedAtAction(nameof(this.GetChatSessionByIdAsync), new { chatId = newChat.Id }, newChat);
}

/// <summary>
/// Get a chat session by id.
/// </summary>
/// <param name="chatId">The chat id.</param>
[HttpGet]
[ActionName("GetChatSessionByIdAsync")]
[Route("chatSession/getChat/{chatId:guid}")]
[ProducesResponseType(StatusCodes.Status200OK)]
[ProducesResponseType(StatusCodes.Status400BadRequest)]
[ProducesResponseType(StatusCodes.Status404NotFound)]
public async Task<IActionResult> GetChatSessionByIdAsync(Guid chatId)
{
var chat = await this._chatSessionRepository.FindByIdAsync(chatId.ToString());
if (chat == null)
{
return this.NotFound($"Chat of id {chatId} not found.");
}

return this.Ok(chat);
}

/// <summary>
/// Get all chat sessions associated with a user. Return an empty list if no chats are found.
/// The regex pattern that is used to match the user id will match the following format:
/// - 2 period separated groups of one or more hyphen-delimitated alphanumeric strings.
/// The pattern matches two GUIDs in canonical textual representation separated by a period.
/// </summary>
/// <param name="userId">The user id.</param>
[HttpGet]
[Route("chatSession/getAllChats/{userId:regex(([[a-z0-9]]+-)+[[a-z0-9]]+\\.([[a-z0-9]]+-)+[[a-z0-9]]+)}")]
[ProducesResponseType(StatusCodes.Status200OK)]
[ProducesResponseType(StatusCodes.Status400BadRequest)]
[ProducesResponseType(StatusCodes.Status404NotFound)]
public async Task<IActionResult> GetAllChatSessionsAsync(string userId)
{
var chats = await this._chatSessionRepository.FindByUserIdAsync(userId);
if (chats == null)
{
// Return an empty list if no chats are found
return this.Ok(new List<ChatSession>());
}

return this.Ok(chats);
}

/// <summary>
/// Get all chat messages for a chat session.
/// The list will be ordered with the first entry being the most recent message.
/// </summary>
/// <param name="chatId">The chat id.</param>
/// <param name="startIdx">The start index at which the first message will be returned.</param>
/// <param name="count">The number of messages to return. -1 will return all messages starting from startIdx.</param>
/// [Authorize]
[HttpGet]
[Route("chatSession/getChatMessages/{chatId:guid}")]
[ProducesResponseType(StatusCodes.Status200OK)]
[ProducesResponseType(StatusCodes.Status400BadRequest)]
[ProducesResponseType(StatusCodes.Status404NotFound)]
public async Task<IActionResult> GetChatMessagesAsync(
Guid chatId,
[FromQuery] int startIdx = 0,
[FromQuery] int count = -1)
{
var chatMessages = await this._chatMessageRepository.FindByChatIdAsync(chatId.ToString());
if (chatMessages == null)
{
return this.NotFound($"No messages found for chat of id {chatId}.");
}

if (startIdx >= chatMessages.Count())
{
return this.BadRequest($"Start index {startIdx} is out of range.");
}
else if (startIdx + count > chatMessages.Count() || count == -1)
{
count = chatMessages.Count() - startIdx;
}

chatMessages = chatMessages.OrderByDescending(m => m.Timestamp).Skip(startIdx).Take(count);
return this.Ok(chatMessages);
}

/// <summary>
/// Edit a chat session.
/// </summary>
/// <param name="chatParameters">Object that contains the parameters to edit the chat.</param>
[HttpPost]
[Route("chatSession/edit")]
[ProducesResponseType(StatusCodes.Status200OK)]
[ProducesResponseType(StatusCodes.Status400BadRequest)]
[ProducesResponseType(StatusCodes.Status404NotFound)]
public async Task<IActionResult> EditChatSessionAsync([FromBody] ChatSession chatParameters)
{
var chatId = chatParameters.Id;

var chat = await this._chatSessionRepository.FindByIdAsync(chatId.ToString());
if (chat == null)
{
return this.NotFound($"Chat of id {chatId} not found.");
}

chat.Title = chatParameters.Title;
await this._chatSessionRepository.UpdateAsync(chat);

return this.Ok(chat);
}

# region Private

/// <summary>
/// Save a bot response to the chat session.
/// </summary>
/// <param name="response">The bot response.</param>
/// <param name="chatId">The chat id.</param>
private async Task SaveResponseAsync(string response, string chatId)
{
// Make sure the chat session exists
await this._chatSessionRepository.FindByIdAsync(chatId);

var chatMessage = ChatMessage.CreateBotResponseMessage(chatId, response);
await this._chatMessageRepository.CreateAsync(chatMessage);
}

# endregion
}
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,6 @@ private string ReadPdfFile(IFormFile file)
fileContent += text;
}

Console.WriteLine(fileContent);

return fileContent;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,6 @@ internal static class FunctionLoadingExtensions
);
kernel.ImportSkill(chatSkill, nameof(ChatSkill));

var chatHistorySkill = new ChatHistorySkill(
chatMessageRepository,
chatSessionRepository,
promptSettings
);
kernel.ImportSkill(chatHistorySkill, nameof(ChatHistorySkill));

var documentMemorySkill = new DocumentMemorySkill(promptSettings, documentMemoryOptions);
kernel.ImportSkill(documentMemorySkill, nameof(DocumentMemorySkill));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
using System.Text.Json.Serialization;
using SemanticKernel.Service.Storage;

namespace SemanticKernel.Service.Skills;
namespace SemanticKernel.Service.Model;

/// <summary>
/// Information about a single chat message.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
using System.Text.Json.Serialization;
using SemanticKernel.Service.Storage;

namespace SemanticKernel.Service.Skills;
namespace SemanticKernel.Service.Model;

/// <summary>
/// A chat session
Expand Down
2 changes: 1 addition & 1 deletion samples/apps/copilot-chat-app/webapi/ServiceExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
using Microsoft.Identity.Web;
using SemanticKernel.Service.Auth;
using SemanticKernel.Service.Config;
using SemanticKernel.Service.Skills;
using SemanticKernel.Service.Model;
using SemanticKernel.Service.Storage;

namespace SemanticKernel.Service;
Expand Down

0 comments on commit 3362c20

Please sign in to comment.