From 11e57799e3c5ca95f79f83acf6da2d5a067d7555 Mon Sep 17 00:00:00 2001 From: SuZhou-Joe Date: Tue, 12 Dec 2023 16:18:00 +0800 Subject: [PATCH] feat: integrate regenerate API (#58) * feat: integrate regenerate API Signed-off-by: SuZhou-Joe * feat: optimize code Signed-off-by: SuZhou-Joe * fix: unit test error Signed-off-by: SuZhou-Joe * feat: optimize code Signed-off-by: SuZhou-Joe * feat: optimize code Signed-off-by: SuZhou-Joe --------- Signed-off-by: SuZhou-Joe --- public/hooks/use_chat_actions.tsx | 8 +- .../chat/messages/message_bubble.test.tsx | 7 + public/tabs/chat/messages/message_bubble.tsx | 8 +- public/types.ts | 2 +- server/routes/chat_routes.ts | 45 ++--- server/services/chat/chat_service.ts | 11 +- .../services/chat/olly_chat_service.test.ts | 176 ++++++++++++++++++ server/services/chat/olly_chat_service.ts | 84 ++++++--- server/types.ts | 2 +- 9 files changed, 278 insertions(+), 65 deletions(-) create mode 100644 server/services/chat/olly_chat_service.test.ts diff --git a/public/hooks/use_chat_actions.tsx b/public/hooks/use_chat_actions.tsx index c78bff31..fd5f3f4a 100644 --- a/public/hooks/use_chat_actions.tsx +++ b/public/hooks/use_chat_actions.tsx @@ -162,7 +162,7 @@ export const useChatActions = (): AssistantActions => { } }; - const regenerate = async () => { + const regenerate = async (interactionId: string) => { if (chatContext.sessionId) { const abortController = new AbortController(); abortControllerRef = abortController; @@ -170,7 +170,11 @@ export const useChatActions = (): AssistantActions => { try { const response = await core.services.http.put(`${ASSISTANT_API.REGENERATE}`, { - body: JSON.stringify({ sessionId: chatContext.sessionId }), + body: JSON.stringify({ + sessionId: chatContext.sessionId, + rootAgentId: chatContext.rootAgentId, + interactionId, + }), }); if (abortController.signal.aborted) { diff --git a/public/tabs/chat/messages/message_bubble.test.tsx b/public/tabs/chat/messages/message_bubble.test.tsx index 4293dec4..d4652f32 100644 --- a/public/tabs/chat/messages/message_bubble.test.tsx +++ b/public/tabs/chat/messages/message_bubble.test.tsx @@ -130,6 +130,13 @@ describe('', () => { contentType: 'markdown', content: 'here are the indices in your cluster: .alert', }} + interaction={{ + input: 'foo', + response: 'bar', + conversation_id: 'foo', + interaction_id: 'bar', + create_time: new Date().toLocaleString(), + }} /> ); expect(screen.queryAllByTitle('regenerate message')).toHaveLength(1); diff --git a/public/tabs/chat/messages/message_bubble.tsx b/public/tabs/chat/messages/message_bubble.tsx index 34199640..8a704e2a 100644 --- a/public/tabs/chat/messages/message_bubble.tsx +++ b/public/tabs/chat/messages/message_bubble.tsx @@ -30,7 +30,7 @@ type MessageBubbleProps = { showActionBar: boolean; showRegenerate?: boolean; shouldActionBarVisibleOnHover?: boolean; - onRegenerate?: () => void; + onRegenerate?: (interactionId: string) => void; } & ( | { message: IMessage; @@ -192,17 +192,17 @@ export const MessageBubble: React.FC = React.memo((props) => )} - {props.showRegenerate && ( + {props.showRegenerate && props.interaction?.interaction_id ? ( props.onRegenerate?.(props.interaction?.interaction_id || '')} title="regenerate message" color="text" iconType="refresh" /> - )} + ) : null} {showFeedback && ( // After feedback, only corresponding thumb icon will be kept and disabled. <> diff --git a/public/types.ts b/public/types.ts index d8f152a3..5fd027e2 100644 --- a/public/types.ts +++ b/public/types.ts @@ -16,7 +16,7 @@ export interface AssistantActions { openChatUI: (sessionId?: string) => void; executeAction: (suggestedAction: ISuggestedAction, message: IMessage) => void; abortAction: (sessionId?: string) => void; - regenerate: () => void; + regenerate: (interactionId: string) => void; } export interface AppPluginStartDependencies { diff --git a/server/routes/chat_routes.ts b/server/routes/chat_routes.ts index 22328032..5e195820 100644 --- a/server/routes/chat_routes.ts +++ b/server/routes/chat_routes.ts @@ -13,7 +13,6 @@ import { } from '../../../../src/core/server'; import { ASSISTANT_API } from '../../common/constants/llm'; import { OllyChatService } from '../services/chat/olly_chat_service'; -import { IMessage, IInput } from '../../common/types/chat_saved_object_attributes'; import { AgentFrameworkStorageService } from '../services/storage/agent_framework_storage_service'; import { RoutesOptions } from '../types'; import { ChatService } from '../services/chat/chat_service'; @@ -64,6 +63,7 @@ const regenerateRoute = { body: schema.object({ sessionId: schema.string(), rootAgentId: schema.string(), + interactionId: schema.string(), }), }, }; @@ -314,42 +314,35 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions) request, response ): Promise> => { - const { sessionId, rootAgentId } = request.body; + const { sessionId, rootAgentId, interactionId } = request.body; const storageService = createStorageService(context); - let messages: IMessage[] = []; const chatService = createChatService(); + let outputs: Awaited> | undefined; + + /** + * Get final answer from Agent framework + */ try { - const session = await storageService.getSession(sessionId); - messages.push(...session.messages); + outputs = await chatService.regenerate({ sessionId, rootAgentId, interactionId }, context); } catch (error) { - return response.custom({ statusCode: error.statusCode || 500, body: error.message }); + context.assistant_plugin.logger.error(error); } - const lastInputIndex = messages.findLastIndex((msg) => msg.type === 'input'); - // Find last input message - const input = messages[lastInputIndex] as IInput; - // Take the messages before last input message as memory as regenerate will exclude the last outputs - messages = messages.slice(0, lastInputIndex); - + /** + * Retrieve latest interactions from memory + */ try { - const outputs = await chatService.requestLLM( - { messages, input, sessionId, rootAgentId }, - context - ); - const title = input.content.substring(0, 50); - const saveMessagesResponse = await storageService.saveMessages( - title, - sessionId, - [...messages, input, ...outputs.messages].filter( - (message) => message.content !== 'AbortError' - ) - ); + const conversation = await storageService.getSession(sessionId); + return response.ok({ - body: { ...saveMessagesResponse, title }, + body: { + ...conversation, + sessionId, + }, }); } catch (error) { - context.assistant_plugin.logger.warn(error); + context.assistant_plugin.logger.error(error); return response.custom({ statusCode: error.statusCode || 500, body: error.message }); } } diff --git a/server/services/chat/chat_service.ts b/server/services/chat/chat_service.ts index ac15adf6..25fe703f 100644 --- a/server/services/chat/chat_service.ts +++ b/server/services/chat/chat_service.ts @@ -10,8 +10,15 @@ import { LLMRequestSchema } from '../../routes/chat_routes'; export interface ChatService { requestLLM( payload: { messages: IMessage[]; input: IInput; sessionId?: string }, - context: RequestHandlerContext, - request: OpenSearchDashboardsRequest + context: RequestHandlerContext + ): Promise<{ + messages: IMessage[]; + memoryId: string; + }>; + + regenerate( + payload: { sessionId: string; interactionId: string; rootAgentId: string }, + context: RequestHandlerContext ): Promise<{ messages: IMessage[]; memoryId: string; diff --git a/server/services/chat/olly_chat_service.test.ts b/server/services/chat/olly_chat_service.test.ts new file mode 100644 index 00000000..1d5f563d --- /dev/null +++ b/server/services/chat/olly_chat_service.test.ts @@ -0,0 +1,176 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { OllyChatService } from './olly_chat_service'; +import { CoreRouteHandlerContext } from '../../../../../src/core/server/core_route_handler_context'; +import { coreMock, httpServerMock } from '../../../../../src/core/server/mocks'; +import { loggerMock } from '../../../../../src/core/server/logging/logger.mock'; + +describe('OllyChatService', () => { + const ollyChatService = new OllyChatService(); + const coreContext = new CoreRouteHandlerContext( + coreMock.createInternalStart(), + httpServerMock.createOpenSearchDashboardsRequest() + ); + const mockedTransport = coreContext.opensearch.client.asCurrentUser.transport + .request as jest.Mock; + const contextMock = { + core: coreContext, + assistant_plugin: { + logger: loggerMock.create(), + }, + }; + beforeEach(() => { + mockedTransport.mockClear(); + }); + it('requestLLM should invoke client call with correct params', async () => { + mockedTransport.mockImplementationOnce(() => { + return { + body: { + inference_results: [ + { + output: [ + { + name: 'memory_id', + result: 'foo', + }, + ], + }, + ], + }, + }; + }); + const result = await ollyChatService.requestLLM( + { + messages: [], + input: { + type: 'input', + contentType: 'text', + content: 'content', + }, + sessionId: '', + rootAgentId: 'rootAgentId', + }, + contextMock + ); + expect(mockedTransport.mock.calls).toMatchInlineSnapshot(` + Array [ + Array [ + Object { + "body": Object { + "parameters": Object { + "question": "content", + "verbose": true, + }, + }, + "method": "POST", + "path": "/_plugins/_ml/agents/rootAgentId/_execute", + }, + Object { + "maxRetries": 0, + "requestTimeout": 300000, + }, + ], + ] + `); + expect(result).toMatchInlineSnapshot(` + Object { + "memoryId": "foo", + "messages": Array [], + } + `); + }); + + it('requestLLM should throw error when transport.request throws error', async () => { + mockedTransport.mockImplementationOnce(() => { + throw new Error('error'); + }); + expect( + ollyChatService.requestLLM( + { + messages: [], + input: { + type: 'input', + contentType: 'text', + content: 'content', + }, + sessionId: '', + rootAgentId: 'rootAgentId', + }, + contextMock + ) + ).rejects.toMatchInlineSnapshot(`[Error: error]`); + }); + + it('regenerate should invoke client call with correct params', async () => { + mockedTransport.mockImplementationOnce(() => { + return { + body: { + inference_results: [ + { + output: [ + { + name: 'memory_id', + result: 'foo', + }, + ], + }, + ], + }, + }; + }); + const result = await ollyChatService.regenerate( + { + sessionId: 'sessionId', + rootAgentId: 'rootAgentId', + interactionId: 'interactionId', + }, + contextMock + ); + expect(mockedTransport.mock.calls).toMatchInlineSnapshot(` + Array [ + Array [ + Object { + "body": Object { + "parameters": Object { + "memory_id": "sessionId", + "regenerate_interaction_id": "interactionId", + "verbose": true, + }, + }, + "method": "POST", + "path": "/_plugins/_ml/agents/rootAgentId/_execute", + }, + Object { + "maxRetries": 0, + "requestTimeout": 300000, + }, + ], + ] + `); + expect(result).toMatchInlineSnapshot(` + Object { + "memoryId": "foo", + "messages": Array [], + } + `); + }); + + it('regenerate should throw error when transport.request throws error', async () => { + mockedTransport.mockImplementationOnce(() => { + throw new Error('error'); + }); + expect( + ollyChatService.regenerate( + { + sessionId: 'sessionId', + rootAgentId: 'rootAgentId', + interactionId: 'interactionId', + }, + contextMock + ) + ).rejects.toMatchInlineSnapshot(`[Error: error]`); + }); +}); diff --git a/server/services/chat/olly_chat_service.ts b/server/services/chat/olly_chat_service.ts index 13981eda..ed2cb57b 100644 --- a/server/services/chat/olly_chat_service.ts +++ b/server/services/chat/olly_chat_service.ts @@ -9,46 +9,35 @@ import { IMessage, IInput } from '../../../common/types/chat_saved_object_attrib import { ChatService } from './chat_service'; import { ML_COMMONS_BASE_API } from '../../utils/constants'; +interface AgentRunPayload { + question?: string; + verbose?: boolean; + memory_id?: string; + regenerate_interaction_id?: string; +} + const MEMORY_ID_FIELD = 'memory_id'; export class OllyChatService implements ChatService { static abortControllers: Map = new Map(); - public async requestLLM( - payload: { messages: IMessage[]; input: IInput; sessionId?: string; rootAgentId: string }, + private async requestAgentRun( + rootAgentId: string, + payload: AgentRunPayload, context: RequestHandlerContext - ): Promise<{ - messages: IMessage[]; - memoryId: string; - }> { - const { input, sessionId, rootAgentId } = payload; - const opensearchClient = context.core.opensearch.client.asCurrentUser; - - if (payload.sessionId) { - OllyChatService.abortControllers.set(payload.sessionId, new AbortController()); + ) { + if (payload.memory_id) { + OllyChatService.abortControllers.set(payload.memory_id, new AbortController()); } + const opensearchClient = context.core.opensearch.client.asCurrentUser; try { - /** - * Wait for an API to fetch root agent id. - */ - const parametersPayload: { - question: string; - verbose?: boolean; - memory_id?: string; - } = { - question: input.content, - verbose: true, - }; - if (sessionId) { - parametersPayload.memory_id = sessionId; - } const agentFrameworkResponse = (await opensearchClient.transport.request( { method: 'POST', path: `${ML_COMMONS_BASE_API}/agents/${rootAgentId}/_execute`, body: { - parameters: parametersPayload, + parameters: payload, }, }, { @@ -69,7 +58,6 @@ export class OllyChatService implements ChatService { }>; const outputBody = agentFrameworkResponse.body.inference_results?.[0]?.output; const memoryIdItem = outputBody?.find((item) => item.name === MEMORY_ID_FIELD); - return { /** * Interactions will be stored in Agent framework, @@ -81,12 +69,50 @@ export class OllyChatService implements ChatService { } catch (error) { throw error; } finally { - if (payload.sessionId) { - OllyChatService.abortControllers.delete(payload.sessionId); + if (payload.memory_id) { + OllyChatService.abortControllers.delete(payload.memory_id); } } } + public async requestLLM( + payload: { messages: IMessage[]; input: IInput; sessionId?: string; rootAgentId: string }, + context: RequestHandlerContext + ): Promise<{ + messages: IMessage[]; + memoryId: string; + }> { + const { input, sessionId, rootAgentId } = payload; + + const parametersPayload: Pick = { + question: input.content, + verbose: true, + }; + + if (sessionId) { + parametersPayload.memory_id = sessionId; + } + + return await this.requestAgentRun(rootAgentId, parametersPayload, context); + } + + async regenerate( + payload: { sessionId: string; interactionId: string; rootAgentId: string }, + context: RequestHandlerContext + ): Promise<{ messages: IMessage[]; memoryId: string }> { + const { sessionId, interactionId, rootAgentId } = payload; + const parametersPayload: Pick< + AgentRunPayload, + 'regenerate_interaction_id' | 'verbose' | 'memory_id' + > = { + memory_id: sessionId, + regenerate_interaction_id: interactionId, + verbose: true, + }; + + return await this.requestAgentRun(rootAgentId, parametersPayload, context); + } + abortAgentExecution(sessionId: string) { if (OllyChatService.abortControllers.has(sessionId)) { OllyChatService.abortControllers.get(sessionId)?.abort(); diff --git a/server/types.ts b/server/types.ts index 948ed5aa..3e93ce8f 100644 --- a/server/types.ts +++ b/server/types.ts @@ -4,7 +4,7 @@ */ import { IMessage, Interaction } from '../common/types/chat_saved_object_attributes'; -import { ILegacyClusterClient, Logger } from '../../../src/core/server'; +import { Logger } from '../../../src/core/server'; // eslint-disable-next-line @typescript-eslint/no-empty-interface export interface AssistantPluginSetup {}