diff --git a/scripts/deploy/main.bicep b/scripts/deploy/main.bicep index 2eb9636d9..4df201eac 100644 --- a/scripts/deploy/main.bicep +++ b/scripts/deploy/main.bicep @@ -637,7 +637,7 @@ resource messageContainer 'Microsoft.DocumentDB/databaseAccounts/sqlDatabases/co } partitionKey: { paths: [ - '/id' + '/chatId' ] kind: 'Hash' version: 2 @@ -699,7 +699,7 @@ resource participantContainer 'Microsoft.DocumentDB/databaseAccounts/sqlDatabase } partitionKey: { paths: [ - '/id' + '/userId' ] kind: 'Hash' version: 2 @@ -730,7 +730,7 @@ resource memorySourcesContainer 'Microsoft.DocumentDB/databaseAccounts/sqlDataba } partitionKey: { paths: [ - '/id' + '/chatId' ] kind: 'Hash' version: 2 diff --git a/webapi/Controllers/ChatHistoryController.cs b/webapi/Controllers/ChatHistoryController.cs index 58f1078c7..d2c0f0903 100644 --- a/webapi/Controllers/ChatHistoryController.cs +++ b/webapi/Controllers/ChatHistoryController.cs @@ -128,7 +128,7 @@ public async Task CreateChatSessionAsync( public async Task GetChatSessionByIdAsync(Guid chatId) { ChatSession? chat = null; - if (await this._sessionRepository.TryFindByIdAsync(chatId.ToString(), v => chat = v)) + if (await this._sessionRepository.TryFindByIdAsync(chatId.ToString(), callback: v => chat = v)) { return this.Ok(chat); } @@ -156,7 +156,7 @@ public async Task GetAllChatSessionsAsync(string userId) foreach (var chatParticipant in chatParticipants) { ChatSession? chat = null; - if (await this._sessionRepository.TryFindByIdAsync(chatParticipant.ChatId, v => chat = v)) + if (await this._sessionRepository.TryFindByIdAsync(chatParticipant.ChatId, callback: v => chat = v)) { chats.Add(chat!); } @@ -231,7 +231,7 @@ public async Task EditChatSessionAsync( } ChatSession? chat = null; - if (await this._sessionRepository.TryFindByIdAsync(chatId, v => chat = v)) + if (await this._sessionRepository.TryFindByIdAsync(chatId, callback: v => chat = v)) { chat!.Title = chatParameters.Title ?? chat!.Title; chat!.SystemDescription = chatParameters.SystemDescription ?? chat!.SystemDescription; @@ -260,7 +260,7 @@ public async Task>> GetSourcesAsync( { this._logger.LogInformation("Get imported sources of chat session {0}", chatId); - if (await this._sessionRepository.TryFindByIdAsync(chatId.ToString(), v => _ = v)) + if (await this._sessionRepository.TryFindByIdAsync(chatId.ToString())) { var sources = await this._sourceRepository.FindByChatIdAsync(chatId.ToString()); return this.Ok(sources); diff --git a/webapi/Controllers/ChatMemoryController.cs b/webapi/Controllers/ChatMemoryController.cs index 3f8e8d276..2190c93eb 100644 --- a/webapi/Controllers/ChatMemoryController.cs +++ b/webapi/Controllers/ChatMemoryController.cs @@ -67,7 +67,7 @@ public async Task GetSemanticMemoriesAsync( var sanitizedMemoryName = memoryName.Replace(Environment.NewLine, string.Empty, StringComparison.Ordinal); // Make sure the chat session exists. - if (!await this._chatSessionRepository.TryFindByIdAsync(chatId, v => _ = v)) + if (!await this._chatSessionRepository.TryFindByIdAsync(chatId)) { this._logger.LogWarning("Chat session: {0} does not exist.", sanitizedChatId); return this.BadRequest($"Chat session: {sanitizedChatId} does not exist."); diff --git a/webapi/Controllers/ChatParticipantController.cs b/webapi/Controllers/ChatParticipantController.cs index 18629ab71..62b5b6ee8 100644 --- a/webapi/Controllers/ChatParticipantController.cs +++ b/webapi/Controllers/ChatParticipantController.cs @@ -64,7 +64,7 @@ public async Task JoinChatAsync( string userId = authInfo.UserId; // Make sure the chat session exists. - if (!await this._chatSessionRepository.TryFindByIdAsync(chatId, v => _ = v)) + if (!await this._chatSessionRepository.TryFindByIdAsync(chatId)) { return this.BadRequest("Chat session does not exist."); } @@ -97,7 +97,7 @@ public async Task JoinChatAsync( public async Task GetAllParticipantsAsync(Guid chatId) { // Make sure the chat session exists. - if (!await this._chatSessionRepository.TryFindByIdAsync(chatId.ToString(), v => _ = v)) + if (!await this._chatSessionRepository.TryFindByIdAsync(chatId.ToString())) { return this.NotFound("Chat session does not exist."); } diff --git a/webapi/Models/Storage/ChatMessage.cs b/webapi/Models/Storage/ChatMessage.cs index 32ff183ca..1ecdc6294 100644 --- a/webapi/Models/Storage/ChatMessage.cs +++ b/webapi/Models/Storage/ChatMessage.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.Globalization; using System.Text.Json; +using System.Text.Json.Serialization; using CopilotChat.WebApi.Models.Response; using CopilotChat.WebApi.Storage; @@ -109,6 +110,12 @@ public enum ChatMessageType /// public Dictionary? TokenUsage { get; set; } + /// + /// The partition key for the source. + /// + [JsonIgnore] + public string Partition => this.ChatId; + /// /// Create a new chat message. Timestamp is automatically generated. /// diff --git a/webapi/Models/Storage/ChatParticipant.cs b/webapi/Models/Storage/ChatParticipant.cs index 732c94494..11bd86e05 100644 --- a/webapi/Models/Storage/ChatParticipant.cs +++ b/webapi/Models/Storage/ChatParticipant.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Text.Json.Serialization; using CopilotChat.WebApi.Storage; namespace CopilotChat.WebApi.Models.Storage; @@ -26,6 +27,12 @@ public class ChatParticipant : IStorageEntity /// public string ChatId { get; set; } + /// + /// The partition key for the source. + /// + [JsonIgnore] + public string Partition => this.UserId; + public ChatParticipant(string userId, string chatId) { this.Id = Guid.NewGuid().ToString(); diff --git a/webapi/Models/Storage/ChatSession.cs b/webapi/Models/Storage/ChatSession.cs index 639c5ef86..329358742 100644 --- a/webapi/Models/Storage/ChatSession.cs +++ b/webapi/Models/Storage/ChatSession.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Text.Json.Serialization; using CopilotChat.WebApi.Storage; namespace CopilotChat.WebApi.Models.Storage; @@ -37,6 +38,12 @@ public class ChatSession : IStorageEntity /// public double MemoryBalance { get; set; } = 0.5; + /// + /// The partition key for the session. + /// + [JsonIgnore] + public string Partition => this.Id; + /// /// Initializes a new instance of the class. /// diff --git a/webapi/Models/Storage/MemorySource.cs b/webapi/Models/Storage/MemorySource.cs index 1b11003aa..a9488ba46 100644 --- a/webapi/Models/Storage/MemorySource.cs +++ b/webapi/Models/Storage/MemorySource.cs @@ -75,6 +75,12 @@ public class MemorySource : IStorageEntity [JsonPropertyName("tokens")] public long Tokens { get; set; } = 0; + /// + /// The partition key for the source. + /// + [JsonIgnore] + public string Partition => this.ChatId; + /// /// Empty constructor for serialization. /// diff --git a/webapi/Skills/ChatSkills/ChatSkill.cs b/webapi/Skills/ChatSkills/ChatSkill.cs index 2c11b1d95..fdb46c71a 100644 --- a/webapi/Skills/ChatSkills/ChatSkill.cs +++ b/webapi/Skills/ChatSkills/ChatSkill.cs @@ -564,7 +564,7 @@ private async Task AcquireExternalInformationAsync(SKContext context, st private async Task SaveNewMessageAsync(string message, string userId, string userName, string chatId, string type) { // Make sure the chat exists. - if (!await this._chatSessionRepository.TryFindByIdAsync(chatId, v => _ = v)) + if (!await this._chatSessionRepository.TryFindByIdAsync(chatId)) { throw new ArgumentException("Chat session does not exist."); } @@ -597,7 +597,7 @@ private async Task SaveNewMessageAsync(string message, string userI private async Task SaveNewResponseAsync(string response, string prompt, string chatId, string userId, Dictionary? tokenUsage) { // Make sure the chat exists. - if (!await this._chatSessionRepository.TryFindByIdAsync(chatId, v => _ = v)) + if (!await this._chatSessionRepository.TryFindByIdAsync(chatId)) { throw new ArgumentException("Chat session does not exist."); } @@ -775,7 +775,7 @@ private async Task UpdateBotResponseStatusOnClient(string chatId, string status) private async Task SetSystemDescriptionAsync(string chatId) { ChatSession? chatSession = null; - if (!await this._chatSessionRepository.TryFindByIdAsync(chatId, v => chatSession = v)) + if (!await this._chatSessionRepository.TryFindByIdAsync(chatId, callback: v => chatSession = v)) { throw new ArgumentException("Chat session does not exist."); } diff --git a/webapi/Skills/ChatSkills/SemanticChatMemorySkill.cs b/webapi/Skills/ChatSkills/SemanticChatMemorySkill.cs index 25f5f43fa..ae464a398 100644 --- a/webapi/Skills/ChatSkills/SemanticChatMemorySkill.cs +++ b/webapi/Skills/ChatSkills/SemanticChatMemorySkill.cs @@ -56,7 +56,7 @@ public async Task QueryMemoriesAsync( ISemanticTextMemory textMemory) { ChatSession? chatSession = null; - if (!await this._chatSessionRepository.TryFindByIdAsync(chatId, v => chatSession = v)) + if (!await this._chatSessionRepository.TryFindByIdAsync(chatId, callback: v => chatSession = v)) { throw new ArgumentException($"Chat session {chatId} not found."); } diff --git a/webapi/Storage/CosmosDbContext.cs b/webapi/Storage/CosmosDbContext.cs index 6a8e6264d..f3ccb8bc5 100644 --- a/webapi/Storage/CosmosDbContext.cs +++ b/webapi/Storage/CosmosDbContext.cs @@ -70,11 +70,11 @@ public async Task DeleteAsync(T entity) throw new ArgumentOutOfRangeException(nameof(entity.Id), "Entity Id cannot be null or empty."); } - await this._container.DeleteItemAsync(entity.Id, new PartitionKey(entity.Id)); + await this._container.DeleteItemAsync(entity.Id, new PartitionKey(entity.Partition)); } /// - public async Task ReadAsync(string entityId) + public async Task ReadAsync(string entityId, string partitionKey) { if (string.IsNullOrWhiteSpace(entityId)) { @@ -83,7 +83,7 @@ public async Task ReadAsync(string entityId) try { - var response = await this._container.ReadItemAsync(entityId, new PartitionKey(entityId)); + var response = await this._container.ReadItemAsync(entityId, new PartitionKey(partitionKey)); return response.Resource; } catch (CosmosException ex) when (ex.StatusCode == HttpStatusCode.NotFound) diff --git a/webapi/Storage/FileSystemContext.cs b/webapi/Storage/FileSystemContext.cs index 8946231d7..44803f1b8 100644 --- a/webapi/Storage/FileSystemContext.cs +++ b/webapi/Storage/FileSystemContext.cs @@ -65,7 +65,7 @@ public Task DeleteAsync(T entity) } /// - public Task ReadAsync(string entityId) + public Task ReadAsync(string entityId, string partitionKey) { if (string.IsNullOrWhiteSpace(entityId)) { diff --git a/webapi/Storage/IRepository.cs b/webapi/Storage/IRepository.cs index fb2e1e7f4..7e564afc8 100644 --- a/webapi/Storage/IRepository.cs +++ b/webapi/Storage/IRepository.cs @@ -32,14 +32,16 @@ public interface IRepository where T : IStorageEntity /// Finds an entity by its id. /// /// Id of the entity. + /// Partition of the entity. /// An entity - Task FindByIdAsync(string id); + Task FindByIdAsync(string id, string partition); /// /// Tries to find an entity by its id. /// /// Id of the entity. - /// The entity delegate. Note async methods don't support ref or out parameters. + /// Partition of the entity. + /// The entity delegate. Note async methods don't support ref or out parameters. /// True if the entity was found, false otherwise. - Task TryFindByIdAsync(string id, Action entity); + Task TryFindByIdAsync(string id, string partition, Action callback); } diff --git a/webapi/Storage/IStorageContext.cs b/webapi/Storage/IStorageContext.cs index 54e350d37..c8f55dc06 100644 --- a/webapi/Storage/IStorageContext.cs +++ b/webapi/Storage/IStorageContext.cs @@ -20,8 +20,9 @@ public interface IStorageContext where T : IStorageEntity /// Read an entity from the storage context by id. /// /// The entity id. + /// The entity partition /// The entity. - Task ReadAsync(string entityId); + Task ReadAsync(string entityId, string partitionKey); /// /// Create an entity in the storage context. diff --git a/webapi/Storage/IStorageEntity.cs b/webapi/Storage/IStorageEntity.cs index e1aa6ab72..91d5880ef 100644 --- a/webapi/Storage/IStorageEntity.cs +++ b/webapi/Storage/IStorageEntity.cs @@ -5,4 +5,6 @@ namespace CopilotChat.WebApi.Storage; public interface IStorageEntity { string Id { get; set; } + + string Partition { get; } } diff --git a/webapi/Storage/Repository.cs b/webapi/Storage/Repository.cs index d6add3cad..4cdad43fa 100644 --- a/webapi/Storage/Repository.cs +++ b/webapi/Storage/Repository.cs @@ -42,22 +42,21 @@ public Task DeleteAsync(T entity) } /// - public Task FindByIdAsync(string id) + public Task FindByIdAsync(string id, string? partition = null) { - return this.StorageContext.ReadAsync(id); + return this.StorageContext.ReadAsync(id, partition ?? id); } /// - public async Task TryFindByIdAsync(string id, Action entity) + public async Task TryFindByIdAsync(string id, string? partition = null, Action? callback = null) { try { - entity(await this.FindByIdAsync(id)); + callback?.Invoke(await this.FindByIdAsync(id, partition ?? id)); return true; } catch (Exception ex) when (ex is ArgumentOutOfRangeException || ex is KeyNotFoundException) { - entity(default); return false; } } diff --git a/webapi/Storage/VolatileContext.cs b/webapi/Storage/VolatileContext.cs index 4cf7a82aa..122b639d0 100644 --- a/webapi/Storage/VolatileContext.cs +++ b/webapi/Storage/VolatileContext.cs @@ -61,7 +61,7 @@ public Task DeleteAsync(T entity) } /// - public Task ReadAsync(string entityId) + public Task ReadAsync(string entityId, string partitionKey) { if (string.IsNullOrWhiteSpace(entityId)) {