From 58c4dfc1cbf8047bf647fb69cf0f954c84cb0575 Mon Sep 17 00:00:00 2001 From: Dario Gieselaar Date: Tue, 14 Nov 2023 09:46:50 +0100 Subject: [PATCH 01/15] [Obs AI Assistant] UseChat hook --- .../common/chat/streaming.ts | 74 +++++++ .../common/errors/index.ts | 23 +++ .../public/hooks/use_chat.test.ts | 58 ++++++ .../public/hooks/use_chat.ts | 183 ++++++++++++++++++ 4 files changed, 338 insertions(+) create mode 100644 x-pack/plugins/observability_ai_assistant/common/chat/streaming.ts create mode 100644 x-pack/plugins/observability_ai_assistant/common/errors/index.ts create mode 100644 x-pack/plugins/observability_ai_assistant/public/hooks/use_chat.test.ts create mode 100644 x-pack/plugins/observability_ai_assistant/public/hooks/use_chat.ts diff --git a/x-pack/plugins/observability_ai_assistant/common/chat/streaming.ts b/x-pack/plugins/observability_ai_assistant/common/chat/streaming.ts new file mode 100644 index 00000000000000..6ee1751064f09c --- /dev/null +++ b/x-pack/plugins/observability_ai_assistant/common/chat/streaming.ts @@ -0,0 +1,74 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +import { ChatCompletionError as ChatCompletionErrorClass } from '../errors'; +import { Message } from '../types'; + +export enum StreamingChatResponseEventType { + ChatCompletionChunk = 'chatCompletionChunk', + ConversationCreate = 'conversationCreate', + ConversationUpdate = 'conversationUpdate', + MessageAdd = 'messageAdd', + ChatCompletionError = 'chatCompletionError', +} + +type StreamingChatResponseEventBase< + TEventType extends StreamingChatResponseEventType, + TData extends {} +> = { + type: TEventType; +} & TData; + +type ChatCompletionChunkEvent = StreamingChatResponseEventBase< + StreamingChatResponseEventType.ChatCompletionChunk, + { + message: { + content?: string; + function_call?: { + name?: string; + args?: string; + }; + }; + } +>; + +type ConversationCreateEvent = StreamingChatResponseEventBase< + StreamingChatResponseEventType.ConversationCreate, + { + conversation: { + id: string; + }; + } +>; + +type ConversationUpdateEvent = StreamingChatResponseEventBase< + StreamingChatResponseEventType.ConversationUpdate, + { + conversation: { + id: string; + title: string; + last_updated: string; + }; + } +>; + +type MessageAddEvent = StreamingChatResponseEventBase< + StreamingChatResponseEventType.MessageAdd, + Message +>; + +type ChatCompletionErrorEvent = StreamingChatResponseEventBase< + StreamingChatResponseEventType.ChatCompletionError, + typeof ChatCompletionErrorClass +>; + +export type StreamingChatResponseEvent = + | ChatCompletionChunkEvent + | ConversationCreateEvent + | ConversationUpdateEvent + | MessageAddEvent + | ChatCompletionErrorEvent; diff --git a/x-pack/plugins/observability_ai_assistant/common/errors/index.ts b/x-pack/plugins/observability_ai_assistant/common/errors/index.ts new file mode 100644 index 00000000000000..6625b0c9cf9a12 --- /dev/null +++ b/x-pack/plugins/observability_ai_assistant/common/errors/index.ts @@ -0,0 +1,23 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +export enum ChatCompletionErrorCode { + InternalError = 'internalError', +} + +export class ChatCompletionError extends Error { + code: ChatCompletionErrorCode; + + constructor(code: ChatCompletionErrorCode, message: string) { + super(message); + this.code = code; + } +} + +export function isChatCompletionError(error: Error): error is ChatCompletionError { + return error instanceof ChatCompletionError; +} diff --git a/x-pack/plugins/observability_ai_assistant/public/hooks/use_chat.test.ts b/x-pack/plugins/observability_ai_assistant/public/hooks/use_chat.test.ts new file mode 100644 index 00000000000000..70b97a601644cc --- /dev/null +++ b/x-pack/plugins/observability_ai_assistant/public/hooks/use_chat.test.ts @@ -0,0 +1,58 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ +import { useChat } from './use_chat'; + +describe('useChat', () => { + // Tests for initial hook setup and default states + + describe('initially', () => { + it('returns the initial messages including the system message', () => {}); + it('sets chatState to ready', () => {}); + }); + + describe('when calling next()', () => { + it('sets the chatState to loading', () => {}); + + describe('after a partial response it updates the returned messages', () => {}); + + describe('after a completed response it updates the returned messages and the loading state', () => {}); + + describe('after aborting a response it shows the partial message and sets chatState to aborted', () => {}); + + describe('after a response errors out, it shows the partial message and sets chatState to error', () => {}); + }); + + // Tests for the 'next' function behavior + describe('Function next', () => { + it('should handle empty message array correctly', () => {}); + it('should ignore non-user and non-assistant messages with function requests', () => {}); + it('should set chat state to loading on valid next message', () => {}); + it('should handle abort signal correctly during message processing', () => {}); + it('should handle message processing for assistant messages with function request', () => {}); + it('should handle message processing for user messages', () => {}); + it('should handle observable responses correctly', () => {}); + it('should update messages correctly after response', () => {}); + it('should handle errors during message processing', () => {}); + }); + + // Tests for the 'stop' function behavior + describe('Function stop', () => { + it('should abort current operation when stop is called', () => {}); + }); + + // Tests for the state management within the hook + describe('State management', () => { + it('should update chat state correctly', () => {}); + it('should update messages state correctly', () => {}); + it('should handle pending message state correctly', () => {}); + 
}); + + // Tests for cleanup and unmounting behavior + describe('Cleanup and unmounting behavior', () => { + it('should abort any ongoing process on unmount', () => {}); + }); +}); diff --git a/x-pack/plugins/observability_ai_assistant/public/hooks/use_chat.ts b/x-pack/plugins/observability_ai_assistant/public/hooks/use_chat.ts new file mode 100644 index 00000000000000..4805824668344f --- /dev/null +++ b/x-pack/plugins/observability_ai_assistant/public/hooks/use_chat.ts @@ -0,0 +1,183 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { i18n } from '@kbn/i18n'; +import { last } from 'lodash'; +import { useCallback, useEffect, useMemo, useRef, useState } from 'react'; +import { isObservable } from 'rxjs'; +import { type Message, MessageRole } from '../../common'; +import { getAssistantSetupMessage } from '../service/get_assistant_setup_message'; +import type { ObservabilityAIAssistantChatService, PendingMessage } from '../types'; +import { useKibana } from './use_kibana'; + +export enum ChatState { + Ready = 'ready', + Loading = 'loading', + Error = 'error', + Aborted = 'aborted', +} + +interface UseChatResult { + messages: Message[]; + state: ChatState; + next: (messages: Message[]) => void; + stop: () => void; +} + +export function useChat({ + initialMessages, + chatService, + connectorId, +}: { + initialMessages: Message[]; + chatService: ObservabilityAIAssistantChatService; + connectorId: string; +}): UseChatResult { + const [chatState, setChatState] = useState(ChatState.Ready); + + const [messages, setMessages] = useState(initialMessages); + + const [pendingMessage, setPendingMessage] = useState(); + + const abortControllerRef = useRef(new AbortController()); + + const { + services: { notifications }, + } = useKibana(); + + const 
handleSignalAbort = useCallback(() => { + setChatState(ChatState.Aborted); + }, []); + + async function next(nextMessages: Message[]) { + abortControllerRef.current.signal.removeEventListener('abort', handleSignalAbort); + + const lastMessage = last(nextMessages); + + if (!lastMessage) { + setChatState(ChatState.Ready); + return; + } + + const isUserMessage = lastMessage.message.role === MessageRole.User; + const functionCall = lastMessage.message.function_call; + const isAssistantMessageWithFunctionRequest = + lastMessage.message.role === MessageRole.Assistant && functionCall && !!functionCall?.name; + + if (!isUserMessage && !isAssistantMessageWithFunctionRequest) { + setChatState(ChatState.Ready); + return; + } + + const abortController = (abortControllerRef.current = new AbortController()); + + abortController.signal.addEventListener('abort', handleSignalAbort); + + setChatState(ChatState.Loading); + + const allMessages = [ + getAssistantSetupMessage({ contexts: chatService.getContexts() }), + ...nextMessages.filter((message) => message.message.role !== MessageRole.System), + ]; + + function handleError(error: Error) { + notifications.toasts.addError(error, { + title: i18n.translate('xpack.observabilityAiAssistant.failedToLoadResponse', { + defaultMessage: 'Failed to load response from the AI Assistant', + }), + }); + } + + const response = isAssistantMessageWithFunctionRequest + ? 
await chatService + .executeFunction({ + name: functionCall.name, + signal: abortController.signal, + args: functionCall.arguments, + connectorId, + messages: allMessages, + }) + .catch((error) => { + return { + content: JSON.stringify({ + message: error.toString(), + error, + }), + data: undefined, + }; + }) + : chatService.chat({ + messages: allMessages, + connectorId, + }); + + if (isObservable(response)) { + const localPendingMessage = pendingMessage!; + const subscription = response.subscribe({ + next: (nextPendingMessage) => { + setPendingMessage(nextPendingMessage); + }, + complete: () => { + setPendingMessage(undefined); + const allMessagesWithResolved = allMessages.concat({ + message: { + ...localPendingMessage.message, + }, + '@timestamp': new Date().toISOString(), + }); + setMessages(allMessagesWithResolved); + if (localPendingMessage.aborted) { + setChatState(ChatState.Aborted); + } else if (localPendingMessage.error) { + handleError(localPendingMessage.error); + } else { + next(allMessagesWithResolved); + } + }, + error: (error) => { + handleError(error); + }, + }); + + abortController.signal.addEventListener('abort', () => { + subscription.unsubscribe(); + }); + } else { + const allMessagesWithFunctionReply = allMessages.concat({ + '@timestamp': new Date().toISOString(), + message: { + name: functionCall!.name, + role: MessageRole.User, + content: JSON.stringify(response.content), + data: JSON.stringify(response.data), + }, + }); + next(allMessagesWithFunctionReply); + } + } + + useEffect(() => { + return () => { + abortControllerRef.current.abort(); + }; + }, []); + + const memoizedMessages = useMemo(() => { + return pendingMessage + ? 
messages.concat({ ...pendingMessage, '@timestamp': new Date().toISOString() }) + : messages; + }, [messages, pendingMessage]); + + return { + messages: memoizedMessages, + state: chatState, + next, + stop: () => { + abortControllerRef.current.abort(); + }, + }; +} From 97f3f6e5b6a219889432f621c6d3b07db4fbe2a8 Mon Sep 17 00:00:00 2001 From: Dario Gieselaar Date: Mon, 27 Nov 2023 11:02:18 +0100 Subject: [PATCH 02/15] [Obs AI Assistant] refactor hooks, recall on every user message --- .../common/index.ts | 4 + .../observability_ai_assistant/jest.config.js | 2 + .../action_menu_item/action_menu_item.tsx | 35 +- .../components/chat/chat_body.stories.tsx | 6 +- .../public/components/chat/chat_body.tsx | 161 +++-- .../chat/chat_consolidated_items.tsx | 6 +- .../components/chat/chat_flyout.stories.tsx | 7 +- .../public/components/chat/chat_flyout.tsx | 42 +- .../public/components/chat/chat_item.tsx | 7 +- ...chat_item_content_inline_prompt_editor.tsx | 2 +- .../components/chat/chat_prompt_editor.tsx | 16 +- .../components/chat/chat_timeline.stories.tsx | 103 ++- .../public/components/chat/chat_timeline.tsx | 77 ++- .../public/components/chat/types.ts | 2 +- .../public/components/insight/insight.tsx | 214 +----- .../message_panel/message_panel.stories.tsx | 4 +- .../message_panel/message_panel.tsx | 2 +- .../public/hooks/use_abortable_async.ts | 6 +- .../public/hooks/use_chat.test.ts | 474 +++++++++++++- .../public/hooks/use_chat.ts | 283 +++++--- .../public/hooks/use_conversation.test.tsx | 536 +++++++++++++++ .../public/hooks/use_conversation.ts | 279 ++++---- .../public/hooks/use_knowledge_base.tsx | 1 + .../public/hooks/use_once.ts | 21 + .../public/hooks/use_timeline.test.ts | 611 ------------------ .../public/hooks/use_timeline.ts | 387 ----------- .../public/routes/config.tsx | 15 +- .../conversations/conversation_view.tsx | 143 ++-- .../service/create_mock_chat_service.ts | 24 + .../public/utils/builders.ts | 160 +++-- ..._timeline_items_from_conversation.test.tsx | 
597 +++++++++++++++++ .../get_timeline_items_from_conversation.tsx | 102 ++- 32 files changed, 2473 insertions(+), 1856 deletions(-) create mode 100644 x-pack/plugins/observability_ai_assistant/public/hooks/use_conversation.test.tsx create mode 100644 x-pack/plugins/observability_ai_assistant/public/hooks/use_once.ts delete mode 100644 x-pack/plugins/observability_ai_assistant/public/hooks/use_timeline.test.ts delete mode 100644 x-pack/plugins/observability_ai_assistant/public/hooks/use_timeline.ts create mode 100644 x-pack/plugins/observability_ai_assistant/public/service/create_mock_chat_service.ts create mode 100644 x-pack/plugins/observability_ai_assistant/public/utils/get_timeline_items_from_conversation.test.tsx diff --git a/x-pack/plugins/observability_ai_assistant/common/index.ts b/x-pack/plugins/observability_ai_assistant/common/index.ts index 92cd91871da696..a4f181a127cfc8 100644 --- a/x-pack/plugins/observability_ai_assistant/common/index.ts +++ b/x-pack/plugins/observability_ai_assistant/common/index.ts @@ -7,3 +7,7 @@ export type { Message, Conversation } from './types'; export { MessageRole } from './types'; + +export { type StreamingChatResponseEvent, StreamingChatResponseEventType } from './chat/streaming'; + +export { ChatCompletionError, ChatCompletionErrorCode } from './errors'; diff --git a/x-pack/plugins/observability_ai_assistant/jest.config.js b/x-pack/plugins/observability_ai_assistant/jest.config.js index 5eaabe2dcf492c..1d6798f6c7623c 100644 --- a/x-pack/plugins/observability_ai_assistant/jest.config.js +++ b/x-pack/plugins/observability_ai_assistant/jest.config.js @@ -10,4 +10,6 @@ module.exports = { rootDir: '../../..', roots: ['/x-pack/plugins/observability_ai_assistant'], setupFiles: ['/x-pack/plugins/observability_ai_assistant/.storybook/jest_setup.js'], + collectCoverage: true, + coverageReporters: ['html'], }; diff --git a/x-pack/plugins/observability_ai_assistant/public/components/action_menu_item/action_menu_item.tsx 
b/x-pack/plugins/observability_ai_assistant/public/components/action_menu_item/action_menu_item.tsx index ca47c242df495e..e77b055e568fa3 100644 --- a/x-pack/plugins/observability_ai_assistant/public/components/action_menu_item/action_menu_item.tsx +++ b/x-pack/plugins/observability_ai_assistant/public/components/action_menu_item/action_menu_item.tsx @@ -6,19 +6,15 @@ */ import { EuiFlexGroup, EuiFlexItem, EuiHeaderLink, EuiLoadingSpinner } from '@elastic/eui'; import { i18n } from '@kbn/i18n'; -import React, { useState } from 'react'; +import React, { useMemo, useState } from 'react'; import { ObservabilityAIAssistantChatServiceProvider } from '../../context/observability_ai_assistant_chat_service_provider'; import { useAbortableAsync } from '../../hooks/use_abortable_async'; -import { useConversation } from '../../hooks/use_conversation'; -import { useGenAIConnectors } from '../../hooks/use_genai_connectors'; import { useObservabilityAIAssistant } from '../../hooks/use_observability_ai_assistant'; -import { EMPTY_CONVERSATION_TITLE } from '../../i18n'; import { AssistantAvatar } from '../assistant_avatar'; import { ChatFlyout } from '../chat/chat_flyout'; export function ObservabilityAIAssistantActionMenuItem() { const service = useObservabilityAIAssistant(); - const connectors = useGenAIConnectors(); const [isOpen, setIsOpen] = useState(false); @@ -32,14 +28,7 @@ export function ObservabilityAIAssistantActionMenuItem() { [service, isOpen] ); - const [conversationId, setConversationId] = useState(); - - const { conversation, displayedMessages, setDisplayedMessages, save, saveTitle } = - useConversation({ - conversationId, - connectorId: connectors.selectedConnector, - chatService: chatService.value, - }); + const initialMessages = useMemo(() => [], []); if (!service.isEnabled()) { return null; @@ -72,26 +61,12 @@ export function ObservabilityAIAssistantActionMenuItem() { {chatService.value ? 
( { - setIsOpen(() => false); - }} - onChatComplete={(messages) => { - save(messages) - .then((nextConversation) => { - setConversationId(nextConversation.conversation.id); - }) - .catch(() => {}); - }} - onChatUpdate={(nextMessages) => { - setDisplayedMessages(nextMessages); - }} - onChatTitleSave={(newTitle) => { - saveTitle(newTitle); + setIsOpen(false); }} /> diff --git a/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_body.stories.tsx b/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_body.stories.tsx index 4b8d749abf9f41..52c927795c4168 100644 --- a/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_body.stories.tsx +++ b/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_body.stories.tsx @@ -21,8 +21,8 @@ const meta: ComponentMeta = { export default meta; const defaultProps: ComponentStoryObj = { args: { - title: 'My Conversation', - messages: [ + initialTitle: 'My Conversation', + initialMessages: [ getAssistantSetupMessage({ contexts: [] }), { '@timestamp': new Date().toISOString(), @@ -64,8 +64,6 @@ const defaultProps: ComponentStoryObj = { currentUser: { username: 'elastic', }, - onChatUpdate: () => {}, - onChatComplete: () => {}, }, render: (props) => { return ( diff --git a/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_body.tsx b/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_body.tsx index 5e4aa5e0659eb6..88990b3a9e47ba 100644 --- a/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_body.tsx +++ b/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_body.tsx @@ -5,9 +5,8 @@ * 2.0. 
*/ -import React, { useEffect, useRef, useState } from 'react'; -import { flatten, last } from 'lodash'; import { + EuiCallOut, EuiFlexGroup, EuiFlexItem, EuiHorizontalRule, @@ -16,21 +15,26 @@ import { EuiSpacer, } from '@elastic/eui'; import { css } from '@emotion/css'; -import { euiThemeVars } from '@kbn/ui-theme'; import type { AuthenticatedUser } from '@kbn/security-plugin/common'; -import type { Message } from '../../../common/types'; +import { euiThemeVars } from '@kbn/ui-theme'; +import React, { useEffect, useRef, useState } from 'react'; +import { i18n } from '@kbn/i18n'; +import { Conversation, Message, MessageRole } from '../../../common/types'; +import { ChatState } from '../../hooks/use_chat'; +import { useConversation } from '../../hooks/use_conversation'; import type { UseGenAIConnectorsResult } from '../../hooks/use_genai_connectors'; import type { UseKnowledgeBaseResult } from '../../hooks/use_knowledge_base'; -import { useTimeline } from '../../hooks/use_timeline'; import { useLicense } from '../../hooks/use_license'; import { useObservabilityAIAssistantChatService } from '../../hooks/use_observability_ai_assistant_chat_service'; -import { ExperimentalFeatureBanner } from './experimental_feature_banner'; -import { InitialSetupPanel } from './initial_setup_panel'; -import { IncorrectLicensePanel } from './incorrect_license_panel'; +import { StartedFrom } from '../../utils/get_timeline_items_from_conversation'; import { ChatHeader } from './chat_header'; import { ChatPromptEditor } from './chat_prompt_editor'; import { ChatTimeline } from './chat_timeline'; -import { StartedFrom } from '../../utils/get_timeline_items_from_conversation'; +import { ExperimentalFeatureBanner } from './experimental_feature_banner'; +import { IncorrectLicensePanel } from './incorrect_license_panel'; +import { InitialSetupPanel } from './initial_setup_panel'; +import { ChatActionClickType } from './types'; +import { EMPTY_CONVERSATION_TITLE } from '../../i18n'; const 
timelineClassName = css` overflow-y: auto; @@ -45,48 +49,45 @@ const incorrectLicenseContainer = css` padding: ${euiThemeVars.euiPanelPaddingModifiers.paddingMedium}; `; +const chatBodyContainerClassNameWithError = css` + align-self: center; +`; + export function ChatBody({ - title, - loading, - messages, + initialTitle, + initialMessages, + initialConversationId, connectors, knowledgeBase, connectorsManagementHref, modelsManagementHref, - conversationId, currentUser, startedFrom, - onChatUpdate, - onChatComplete, - onSaveTitle, + onConversationUpdate, }: { - title: string; - loading: boolean; - messages: Message[]; + initialTitle?: string; + initialMessages?: Message[]; + initialConversationId?: string; connectors: UseGenAIConnectorsResult; knowledgeBase: UseKnowledgeBaseResult; connectorsManagementHref: string; modelsManagementHref: string; - conversationId?: string; currentUser?: Pick; startedFrom?: StartedFrom; - onChatUpdate: (messages: Message[]) => void; - onChatComplete: (messages: Message[]) => void; - onSaveTitle: (title: string) => void; + onConversationUpdate: (conversation: Conversation) => void; }) { const license = useLicense(); const hasCorrectLicense = license?.hasAtLeast('enterprise'); const chatService = useObservabilityAIAssistantChatService(); - const timeline = useTimeline({ + const { conversation, messages, next, state, stop, saveTitle } = useConversation({ + initialConversationId, + initialMessages, + initialTitle, chatService, - connectors, - currentUser, - messages, - startedFrom, - onChatUpdate, - onChatComplete, + connectorId: connectors.selectedConnector, + onConversationUpdate, }); const timelineContainerRef = useRef(null); @@ -94,7 +95,10 @@ export function ChatBody({ let footer: React.ReactNode; const isLoading = Boolean( - connectors.loading || knowledgeBase.status.loading || last(flatten(timeline.items))?.loading + connectors.loading || + knowledgeBase.status.loading || + state === ChatState.Loading || + conversation.loading ); 
const containerClassName = css` @@ -139,12 +143,12 @@ export function ChatBody({ }); const handleCopyConversation = () => { - const content = JSON.stringify({ title, messages }); + const content = JSON.stringify({ title: initialTitle, messages }); navigator.clipboard?.writeText(content || ''); }; - if (!hasCorrectLicense && !conversationId) { + if (!hasCorrectLicense && !initialConversationId) { footer = ( <> @@ -155,19 +159,25 @@ export function ChatBody({ - + { + next(messages.concat(message)); + }} + /> ); - } else if (connectors.loading || knowledgeBase.status.loading) { + } else if (connectors.loading || knowledgeBase.status.loading || conversation.loading) { footer = ( ); - } else if (connectors.connectors?.length === 0 && !conversationId) { + } else if (connectors.connectors?.length === 0 && !initialConversationId) { footer = ( { + const indexOf = messages.indexOf(editedMessage); + next(messages.slice(0, indexOf).concat(newMessage)); + }} + onFeedback={(message, feedback) => {}} + onRegenerate={(message) => { + const indexOf = messages.indexOf(message); + next(messages.slice(0, indexOf)); + }} + onStopGenerating={() => { + stop(); + }} onActionClick={(payload) => { setStickToBottom(true); - return timeline.onActionClick(payload); + switch (payload.type) { + case ChatActionClickType.executeEsqlQuery: + next( + messages.concat({ + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.Assistant, + content: '', + function_call: { + name: 'execute_query', + arguments: JSON.stringify({ + query: payload.query, + }), + trigger: MessageRole.User, + }, + }, + }) + ); + break; + } }} /> @@ -207,7 +249,7 @@ export function ChatBody({ disabled={!connectors.selectedConnector || !hasCorrectLicense} onSubmit={(message) => { setStickToBottom(true); - return timeline.onSubmit(message); + return next(messages.concat(message)); }} /> @@ -224,20 +266,41 @@ export function ChatBody({ ) : null} - + + {conversation.error ? 
( + + {i18n.translate('xpack.observabilityAiAssistant.couldNotFindConversationContent', { + defaultMessage: + 'Could not find a conversation with id {conversationId}. Make sure the conversation exists and you have access to it.', + values: { conversationId: initialConversationId }, + })} + + ) : null} + { + saveTitle(newTitle); + }} /> diff --git a/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_consolidated_items.tsx b/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_consolidated_items.tsx index 346ccfe501f37f..18a98c09ff3874 100644 --- a/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_consolidated_items.tsx +++ b/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_consolidated_items.tsx @@ -120,12 +120,12 @@ export function ChatConsolidatedItems({ key={index} {...item} onFeedbackClick={(feedback) => { - onFeedback(item, feedback); + onFeedback(item.message, feedback); }} onRegenerateClick={() => { - onRegenerate(item); + onRegenerate(item.message); }} - onEditSubmit={(message) => onEditSubmit(item, message)} + onEditSubmit={(message) => onEditSubmit(item.message, message)} onStopGeneratingClick={onStopGenerating} onActionClick={onActionClick} /> diff --git a/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_flyout.stories.tsx b/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_flyout.stories.tsx index 30f56a6ab63a06..bf54e20c3ec506 100644 --- a/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_flyout.stories.tsx +++ b/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_flyout.stories.tsx @@ -29,13 +29,10 @@ const Template: ComponentStory = (props: ChatFlyoutProps) => { const defaultProps: ChatFlyoutProps = { isOpen: true, - title: 'How is this working', - messages: [getAssistantSetupMessage({ contexts: [] })], + initialTitle: 'How is this working', + initialMessages: [getAssistantSetupMessage({ contexts: 
[] })], startedFrom: 'appTopNavbar', onClose: () => {}, - onChatComplete: () => {}, - onChatTitleSave: () => {}, - onChatUpdate: () => {}, }; export const ChatFlyout = Template.bind({}); diff --git a/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_flyout.tsx b/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_flyout.tsx index ef4635d873e8aa..ef4961b87b7c68 100644 --- a/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_flyout.tsx +++ b/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_flyout.tsx @@ -7,7 +7,7 @@ import { EuiFlexGroup, EuiFlexItem, EuiFlyout, EuiLink, EuiPanel, useEuiTheme } from '@elastic/eui'; import { css } from '@emotion/css'; import { i18n } from '@kbn/i18n'; -import React from 'react'; +import React, { useState } from 'react'; import type { Message } from '../../../common/types'; import { useCurrentUser } from '../../hooks/use_current_user'; import { useGenAIConnectors } from '../../hooks/use_genai_connectors'; @@ -28,25 +28,17 @@ const bodyClassName = css` `; export function ChatFlyout({ - title, - messages, - conversationId, + initialTitle, + initialMessages, + onClose, isOpen, startedFrom, - onClose, - onChatUpdate, - onChatComplete, - onChatTitleSave, }: { - title: string; - messages: Message[]; - conversationId?: string; + initialTitle: string; + initialMessages: Message[]; isOpen: boolean; startedFrom: StartedFrom; onClose: () => void; - onChatUpdate: (messages: Message[]) => void; - onChatComplete: (messages: Message[]) => void; - onChatTitleSave: (title: string) => void; }) { const { euiTheme } = useEuiTheme(); const { @@ -61,6 +53,8 @@ export function ChatFlyout({ const knowledgeBase = useKnowledgeBase(); + const [conversationId, setConversationId] = useState(undefined); + return isOpen ? 
( { - if (onChatUpdate) { - onChatUpdate(nextMessages); - } - }} - onChatComplete={(nextMessages) => { - if (onChatComplete) { - onChatComplete(nextMessages); - } - }} - onSaveTitle={(newTitle) => { - onChatTitleSave(newTitle); + onConversationUpdate={(conversation) => { + setConversationId(conversation.conversation.id); }} /> diff --git a/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_item.tsx b/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_item.tsx index 7ec1084a26b22a..09cb52a1c30b97 100644 --- a/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_item.tsx +++ b/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_item.tsx @@ -27,7 +27,7 @@ import { FailedToLoadResponse } from '../message_panel/failed_to_load_response'; import { ChatActionClickHandler } from './types'; export interface ChatItemProps extends ChatTimelineItem { - onEditSubmit: (message: Message) => Promise; + onEditSubmit: (message: Message) => void; onFeedbackClick: (feedback: Feedback) => void; onRegenerateClick: () => void; onStopGeneratingClick: () => void; @@ -66,13 +66,14 @@ const noPanelMessageClassName = css` export function ChatItem({ actions: { canCopy, canEdit, canGiveFeedback, canRegenerate }, display: { collapsed }, + message: { + message: { function_call: functionCall, role }, + }, content, currentUser, element, error, - function_call: functionCall, loading, - role, title, onEditSubmit, onFeedbackClick, diff --git a/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_item_content_inline_prompt_editor.tsx b/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_item_content_inline_prompt_editor.tsx index df57f069d91d13..4f702eed2e16d1 100644 --- a/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_item_content_inline_prompt_editor.tsx +++ b/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_item_content_inline_prompt_editor.tsx @@ 
-22,7 +22,7 @@ interface Props { | undefined; loading: boolean; editing: boolean; - onSubmit: (message: Message) => Promise; + onSubmit: (message: Message) => void; onActionClick: ChatActionClickHandler; } export function ChatItemContentInlinePromptEditor({ diff --git a/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_prompt_editor.tsx b/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_prompt_editor.tsx index 21e9e3871205c8..f34288d5755e62 100644 --- a/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_prompt_editor.tsx +++ b/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_prompt_editor.tsx @@ -5,20 +5,20 @@ * 2.0. */ -import React, { useCallback, useEffect, useRef, useState } from 'react'; import { EuiButtonEmpty, EuiButtonIcon, EuiFlexGroup, EuiFlexItem, + EuiFocusTrap, EuiPanel, EuiSpacer, EuiTextArea, keys, - EuiFocusTrap, } from '@elastic/eui'; import { i18n } from '@kbn/i18n'; import { CodeEditor } from '@kbn/kibana-react-plugin/public'; +import React, { useCallback, useEffect, useRef, useState } from 'react'; import { MessageRole, type Message } from '../../../common'; import { useJsonEditorModel } from '../../hooks/use_json_editor_model'; import { FunctionListPopover } from './function_list_popover'; @@ -30,7 +30,7 @@ export interface ChatPromptEditorProps { initialSelectedFunctionName?: string; initialFunctionPayload?: string; trigger?: MessageRole; - onSubmit: (message: Message) => Promise; + onSubmit: (message: Message) => void; } export function ChatPromptEditor({ @@ -216,7 +216,10 @@ export function ChatPromptEditor({ {selectedFunctionName ? ( 8 ? 
'200px' : '120px'} @@ -284,7 +287,10 @@ export function ChatPromptEditor({ = (props: ChatTimelineProps) => { - const [count, setCount] = useState(props.items.length - 1); + const [count, setCount] = useState(props.messages.length - 1); return ( <> - index <= count)} /> + index <= count)} /> setCount(count >= 0 && count < props.items.length - 1 ? count + 1 : 0)} + onClick={() => setCount(count >= 0 && count < props.messages.length - 1 ? count + 1 : 0)} > Add message @@ -61,13 +63,23 @@ const defaultProps: ComponentProps = { installError: undefined, install: async () => {}, }, - items: [ - buildChatInitItem(), - buildUserChatItem(), - buildAssistantChatItem(), - buildUserChatItem({ content: 'How does it work?' }), - buildAssistantChatItem({ - content: `The way functions work depends on whether we are talking about mathematical functions or programming functions. Let's explore both: + chatService: { + hasRenderFunction: () => false, + } as unknown as ObservabilityAIAssistantChatService, + chatState: ChatState.Ready, + hasConnector: true, + currentUser: { + full_name: 'John Doe', + username: 'johndoe', + }, + messages: [ + buildSystemMessage(), + buildUserMessage(), + buildAssistantMessage(), + buildUserMessage({ message: { content: 'How does it work?' } }), + buildAssistantMessage({ + message: { + content: `The way functions work depends on whether we are talking about mathematical functions or programming functions. Let's explore both: Mathematical Functions: In mathematics, a function maps input values to corresponding output values based on a specific rule or expression. The general process of how a mathematical function works can be summarized as follows: @@ -78,55 +90,34 @@ const defaultProps: ComponentProps = { Step 3: Output - After processing the input, the function produces an output value, denoted as 'f(x)' or 'y'. This output represents the dependent variable and is the result of applying the function's rule to the input. 
Step 4: Uniqueness - A well-defined mathematical function ensures that each input value corresponds to exactly one output value. In other words, the function should yield the same output for the same input whenever it is called.`, + }, }), - buildUserChatItem({ - content: 'Can you execute a function?', + buildUserMessage({ + message: { content: 'Can you execute a function?' }, }), - buildAssistantChatItem({ - content: 'Sure, I can do that.', - title: 'suggested a function', - function_call: { - name: 'a_function', - arguments: '{ "foo": "bar" }', - trigger: MessageRole.Assistant, - }, - actions: { - canEdit: false, - canCopy: true, - canGiveFeedback: true, - canRegenerate: true, + buildAssistantMessage({ + message: { + content: 'Sure, I can do that.', + function_call: { + name: 'a_function', + arguments: '{ "foo": "bar" }', + trigger: MessageRole.Assistant, + }, }, }), - buildFunctionChatItem({ - content: '{ "message": "The arguments are wrong" }', - error: new Error(), - actions: { - canRegenerate: false, - canEdit: true, - canGiveFeedback: false, - canCopy: true, - }, + buildFunctionResponseMessage({ + message: { content: '{ "message": "The arguments are wrong" }' }, }), - buildAssistantChatItem({ - content: '', - title: 'suggested a function', - function_call: { - name: 'a_function', - arguments: '{ "bar": "foo" }', - trigger: MessageRole.Assistant, - }, - actions: { - canEdit: true, - canCopy: true, - canGiveFeedback: true, - canRegenerate: true, + buildAssistantMessage({ + message: { + content: '', + function_call: { + name: 'a_function', + arguments: '{ "bar": "foo" }', + trigger: MessageRole.Assistant, + }, }, }), - buildFunctionChatItem({ - content: '', - title: 'are executing a function', - loading: true, - }), ], onEdit: async () => {}, onFeedback: () => {}, diff --git a/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_timeline.tsx b/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_timeline.tsx index 
e42924e7656097..065208695349d1 100644 --- a/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_timeline.tsx +++ b/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_timeline.tsx @@ -5,7 +5,7 @@ * 2.0. */ -import React, { ReactNode } from 'react'; +import React, { ReactNode, useMemo } from 'react'; import { css } from '@emotion/css'; import { EuiCommentList } from '@elastic/eui'; import type { AuthenticatedUser } from '@kbn/security-plugin/common'; @@ -16,6 +16,12 @@ import type { Feedback } from '../feedback_buttons'; import { type Message } from '../../../common'; import type { UseKnowledgeBaseResult } from '../../hooks/use_knowledge_base'; import type { ChatActionClickHandler } from './types'; +import { + getTimelineItemsfromConversation, + StartedFrom, +} from '../../utils/get_timeline_items_from_conversation'; +import { ObservabilityAIAssistantChatService } from '../../types'; +import { ChatState } from '../../hooks/use_chat'; export interface ChatTimelineItem extends Pick { @@ -35,27 +41,70 @@ export interface ChatTimelineItem element?: React.ReactNode; currentUser?: Pick; error?: any; + message: Message; } export interface ChatTimelineProps { - items: Array; + messages: Message[]; knowledgeBase: UseKnowledgeBaseResult; - onEdit: (item: ChatTimelineItem, message: Message) => Promise; - onFeedback: (item: ChatTimelineItem, feedback: Feedback) => void; - onRegenerate: (item: ChatTimelineItem) => void; + chatService: ObservabilityAIAssistantChatService; + hasConnector: boolean; + chatState: ChatState; + currentUser?: Pick; + startedFrom?: StartedFrom; + onEdit: (message: Message, messageAfterEdit: Message) => void; + onFeedback: (message: Message, feedback: Feedback) => void; + onRegenerate: (message: Message) => void; onStopGenerating: () => void; onActionClick: ChatActionClickHandler; } export function ChatTimeline({ - items = [], + messages, knowledgeBase, + chatService, + hasConnector, + currentUser, + startedFrom, 
onEdit, onFeedback, onRegenerate, onStopGenerating, onActionClick, + chatState, }: ChatTimelineProps) { + const items = useMemo(() => { + const timelineItems = getTimelineItemsfromConversation({ + chatService, + hasConnector, + messages, + currentUser, + startedFrom, + chatState, + }); + + const consolidatedChatItems: Array = []; + let currentGroup: ChatTimelineItem[] | null = null; + + for (const item of timelineItems) { + if (item.display.hide || !item) continue; + + if (item.display.collapsed) { + if (currentGroup) { + currentGroup.push(item); + } else { + currentGroup = [item]; + consolidatedChatItems.push(currentGroup); + } + } else { + consolidatedChatItems.push(item); + currentGroup = null; + } + } + + return consolidatedChatItems; + }, [chatService, hasConnector, messages, currentUser, startedFrom, chatState]); + return ( ) : ( - items.map((item, index) => - Array.isArray(item) ? ( + items.map((item, index) => { + return Array.isArray(item) ? ( { - onFeedback(item, feedback); + onFeedback(item.message, feedback); }} onRegenerateClick={() => { - onRegenerate(item); + onRegenerate(item.message); + }} + onEditSubmit={(message) => { + onEdit(item.message, message); }} - onEditSubmit={(message) => onEdit(item, message)} onStopGeneratingClick={onStopGenerating} onActionClick={onActionClick} /> - ) - ) + ); + }) )} ); diff --git a/x-pack/plugins/observability_ai_assistant/public/components/chat/types.ts b/x-pack/plugins/observability_ai_assistant/public/components/chat/types.ts index 4edd3d7dcdda08..017f2f81a6f63e 100644 --- a/x-pack/plugins/observability_ai_assistant/public/components/chat/types.ts +++ b/x-pack/plugins/observability_ai_assistant/public/components/chat/types.ts @@ -20,4 +20,4 @@ export enum ChatActionClickType { executeEsqlQuery = 'executeEsqlQuery', } -export type ChatActionClickHandler = (payload: ChatActionClickPayload) => Promise; +export type ChatActionClickHandler = (payload: ChatActionClickPayload) => void; diff --git 
a/x-pack/plugins/observability_ai_assistant/public/components/insight/insight.tsx b/x-pack/plugins/observability_ai_assistant/public/components/insight/insight.tsx index 8f9c477c5d4ef8..48ba86a98fbe8f 100644 --- a/x-pack/plugins/observability_ai_assistant/public/components/insight/insight.tsx +++ b/x-pack/plugins/observability_ai_assistant/public/components/insight/insight.tsx @@ -4,20 +4,17 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ -import React, { useCallback, useEffect, useMemo, useRef, useState } from 'react'; -import { first } from 'lodash'; import { EuiFlexGroup, EuiFlexItem } from '@elastic/eui'; -import { AbortError } from '@kbn/kibana-utils-plugin/common'; -import { isObservable, Subscription } from 'rxjs'; +import { last } from 'lodash'; +import React, { useEffect, useRef, useState } from 'react'; import { MessageRole, type Message } from '../../../common/types'; import { ObservabilityAIAssistantChatServiceProvider } from '../../context/observability_ai_assistant_chat_service_provider'; -import { useKibana } from '../../hooks/use_kibana'; import { useAbortableAsync } from '../../hooks/use_abortable_async'; -import { useConversation } from '../../hooks/use_conversation'; +import { ChatState, useChat } from '../../hooks/use_chat'; import { useGenAIConnectors } from '../../hooks/use_genai_connectors'; +import { useKibana } from '../../hooks/use_kibana'; import { useObservabilityAIAssistant } from '../../hooks/use_observability_ai_assistant'; import { useObservabilityAIAssistantChatService } from '../../hooks/use_observability_ai_assistant_chat_service'; -import type { PendingMessage } from '../../types'; import { getConnectorsManagementHref } from '../../utils/get_connectors_management_href'; import { RegenerateResponseButton } from '../buttons/regenerate_response_button'; import { StartChatButton } from '../buttons/start_chat_button'; @@ -40,197 +37,40 @@ function ChatContent({ }) { const chatService = 
useObservabilityAIAssistantChatService(); - const [pendingMessage, setPendingMessage] = useState(); - - const [loading, setLoading] = useState(false); - const [subscription, setSubscription] = useState(); + const initialMessagesRef = useRef(initialMessages); - const [conversationId, setConversationId] = useState(); - - const { - conversation, - displayedMessages, - setDisplayedMessages, - getSystemMessage, - save, - saveTitle, - } = useConversation({ - conversationId, - connectorId, + const { messages, next, state, stop } = useChat({ chatService, + connectorId, initialMessages, }); - const conversationTitle = conversationId - ? conversation.value?.conversation.title || '' - : defaultTitle; - - const controllerRef = useRef(new AbortController()); - - const reloadRecalledMessages = useCallback( - async (messages: Message[]) => { - controllerRef.current.abort(); - - const controller = (controllerRef.current = new AbortController()); - - const isStartOfConversation = - messages.some((message) => message.message.role === MessageRole.Assistant) === false; - - if (isStartOfConversation && chatService.hasFunction('recall')) { - // manually execute recall function and append to list of - // messages - const functionCall = { - name: 'recall', - args: JSON.stringify({ queries: [], contexts: [] }), - }; - - const response = await chatService.executeFunction({ - ...functionCall, - messages, - signal: controller.signal, - connectorId, - }); - - if (isObservable(response)) { - throw new Error('Recall function unexpectedly returned an Observable'); - } - - return [ - { - '@timestamp': new Date().toISOString(), - message: { - role: MessageRole.Assistant, - content: '', - function_call: { - name: functionCall.name, - arguments: functionCall.args, - trigger: MessageRole.User as const, - }, - }, - }, - { - '@timestamp': new Date().toISOString(), - message: { - role: MessageRole.User, - name: functionCall.name, - content: JSON.stringify(response.content), - }, - }, - ]; - } - - return 
[]; - }, - [chatService, connectorId] + const lastAssistantResponse = last( + messages.filter((message) => message.message.role === MessageRole.Assistant) ); - const reloadConversation = useCallback(async () => { - setLoading(true); - - setDisplayedMessages(initialMessages); - setPendingMessage(undefined); - - const messages = [getSystemMessage(), ...initialMessages]; - - const recalledMessages = await reloadRecalledMessages(messages); - const next = messages.concat(recalledMessages); - - setDisplayedMessages(next); - - let lastPendingMessage: PendingMessage | undefined; - - const nextSubscription = chatService - .chat({ messages: next, connectorId, function: 'none' }) - .subscribe({ - next: (msg) => { - lastPendingMessage = msg; - setPendingMessage(() => msg); - }, - complete: () => { - setDisplayedMessages((prev) => - prev.concat({ - '@timestamp': new Date().toISOString(), - ...lastPendingMessage!, - }) - ); - setPendingMessage(lastPendingMessage); - setLoading(false); - }, - }); - - setSubscription(nextSubscription); - }, [ - reloadRecalledMessages, - chatService, - connectorId, - initialMessages, - getSystemMessage, - setDisplayedMessages, - ]); - useEffect(() => { - reloadConversation(); - }, [reloadConversation]); - - useEffect(() => { - setDisplayedMessages(initialMessages); - }, [initialMessages, setDisplayedMessages]); + next(initialMessagesRef.current); + }, [next]); const [isOpen, setIsOpen] = useState(false); - const messagesWithPending = useMemo(() => { - return pendingMessage - ? 
displayedMessages.concat({ - '@timestamp': new Date().toISOString(), - message: { - ...pendingMessage.message, - }, - }) - : displayedMessages; - }, [pendingMessage, displayedMessages]); - - const firstAssistantMessage = first( - messagesWithPending.filter( - (message) => - message.message.role === MessageRole.Assistant && - (!message.message.function_call?.trigger || - message.message.function_call.trigger === MessageRole.Assistant) - ) - ); - return ( <> {}} /> } - error={pendingMessage?.error} + error={state === ChatState.Error} controls={ - loading ? ( + state === ChatState.Loading ? ( { - subscription?.unsubscribe(); - setLoading(false); - setDisplayedMessages((prevMessages) => - prevMessages.concat({ - '@timestamp': new Date().toISOString(), - message: { - ...pendingMessage!.message, - }, - }) - ); - setPendingMessage((prev) => ({ - message: { - role: MessageRole.Assistant, - ...prev?.message, - }, - aborted: true, - error: new AbortError(), - })); + stop(); }} /> ) : ( @@ -238,7 +78,7 @@ function ChatContent({ { - reloadConversation(); + next(initialMessages); }} /> @@ -254,27 +94,13 @@ function ChatContent({ } /> { - setIsOpen(() => false); + setIsOpen(false); }} - messages={displayedMessages} - conversationId={conversationId} + initialMessages={messages} + initialTitle={defaultTitle} startedFrom="contextualInsight" - onChatComplete={(nextMessages) => { - save(nextMessages) - .then((nextConversation) => { - setConversationId(nextConversation.conversation.id); - }) - .catch(() => {}); - }} - onChatUpdate={(nextMessages) => { - setDisplayedMessages(nextMessages); - }} - onChatTitleSave={(newTitle) => { - saveTitle(newTitle); - }} /> ); diff --git a/x-pack/plugins/observability_ai_assistant/public/components/message_panel/message_panel.stories.tsx b/x-pack/plugins/observability_ai_assistant/public/components/message_panel/message_panel.stories.tsx index 393bbeee28f8da..9af3e3cf285b40 100644 --- 
a/x-pack/plugins/observability_ai_assistant/public/components/message_panel/message_panel.stories.tsx +++ b/x-pack/plugins/observability_ai_assistant/public/components/message_panel/message_panel.stories.tsx @@ -71,7 +71,7 @@ export const ContentFailed: ComponentStoryObj = { onActionClick={async () => {}} /> ), - error: new Error(), + error: true, }, }; @@ -111,7 +111,7 @@ export const Controls: ComponentStoryObj = { onActionClick={async () => {}} /> ), - error: new Error(), + error: true, controls: {}} />, }, }; diff --git a/x-pack/plugins/observability_ai_assistant/public/components/message_panel/message_panel.tsx b/x-pack/plugins/observability_ai_assistant/public/components/message_panel/message_panel.tsx index 78a1e8fae47788..820ab2a55c271b 100644 --- a/x-pack/plugins/observability_ai_assistant/public/components/message_panel/message_panel.tsx +++ b/x-pack/plugins/observability_ai_assistant/public/components/message_panel/message_panel.tsx @@ -9,7 +9,7 @@ import React from 'react'; import { FailedToLoadResponse } from './failed_to_load_response'; interface Props { - error?: Error; + error?: boolean; body?: React.ReactNode; controls?: React.ReactNode; } diff --git a/x-pack/plugins/observability_ai_assistant/public/hooks/use_abortable_async.ts b/x-pack/plugins/observability_ai_assistant/public/hooks/use_abortable_async.ts index a37624d441757f..afd776dc139909 100644 --- a/x-pack/plugins/observability_ai_assistant/public/hooks/use_abortable_async.ts +++ b/x-pack/plugins/observability_ai_assistant/public/hooks/use_abortable_async.ts @@ -18,9 +18,9 @@ export type AbortableAsyncState = (T extends Promise : State) & { refresh: () => void }; export function useAbortableAsync( - fn: ({}: { signal: AbortSignal }) => T, + fn: ({}: { signal: AbortSignal }) => T | Promise, deps: any[], - options?: { clearValueOnNext?: boolean } + options?: { clearValueOnNext?: boolean; defaultValue?: () => T } ): AbortableAsyncState { const clearValueOnNext = options?.clearValueOnNext; @@ 
-30,7 +30,7 @@ export function useAbortableAsync( const [error, setError] = useState(); const [loading, setLoading] = useState(false); - const [value, setValue] = useState(); + const [value, setValue] = useState(options?.defaultValue); useEffect(() => { controllerRef.current.abort(); diff --git a/x-pack/plugins/observability_ai_assistant/public/hooks/use_chat.test.ts b/x-pack/plugins/observability_ai_assistant/public/hooks/use_chat.test.ts index 70b97a601644cc..22bc997a4c925b 100644 --- a/x-pack/plugins/observability_ai_assistant/public/hooks/use_chat.test.ts +++ b/x-pack/plugins/observability_ai_assistant/public/hooks/use_chat.test.ts @@ -4,55 +4,461 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ -import { useChat } from './use_chat'; +import type { DeeplyMockedKeys } from '@kbn/utility-types-jest'; +import { type RenderHookResult, renderHook, act } from '@testing-library/react-hooks'; +import { Subject } from 'rxjs'; +import { MessageRole } from '../../common'; +import type { ObservabilityAIAssistantChatService, PendingMessage } from '../types'; +import { type UseChatResult, useChat, type UseChatProps, ChatState } from './use_chat'; +import * as useKibanaModule from './use_kibana'; + +type MockedChatService = DeeplyMockedKeys; + +const mockChatService: MockedChatService = { + chat: jest.fn(), + executeFunction: jest.fn(), + getContexts: jest.fn().mockReturnValue([{ name: 'core', description: '' }]), + getFunctions: jest.fn().mockReturnValue([]), + hasFunction: jest.fn().mockReturnValue(false), + hasRenderFunction: jest.fn().mockReturnValue(true), + renderFunction: jest.fn(), +}; + +const addErrorMock = jest.fn(); + +jest.spyOn(useKibanaModule, 'useKibana').mockReturnValue({ + services: { + notifications: { + toasts: { + addError: addErrorMock, + }, + }, + }, +} as any); + +let hookResult: RenderHookResult; describe('useChat', () => { - // Tests for initial hook setup and default states + beforeEach(() => { + 
jest.clearAllMocks(); + }); describe('initially', () => { - it('returns the initial messages including the system message', () => {}); - it('sets chatState to ready', () => {}); + beforeEach(() => { + hookResult = renderHook(useChat, { + initialProps: { + connectorId: 'my-connector', + chatService: mockChatService, + initialMessages: [ + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.User, + content: 'hello', + }, + }, + ], + } as UseChatProps, + }); + }); + + it('returns the initial messages including the system message', () => { + const { messages } = hookResult.result.current; + expect(messages.length).toBe(2); + expect(messages[0].message.role).toBe('system'); + expect(messages[1].message.content).toBe('hello'); + }); + + it('sets chatState to ready', () => { + expect(hookResult.result.current.state).toBe(ChatState.Ready); + }); }); describe('when calling next()', () => { - it('sets the chatState to loading', () => {}); + let subject: Subject; - describe('after a partial response it updates the returned messages', () => {}); + beforeEach(() => { + hookResult = renderHook(useChat, { + initialProps: { + connectorId: 'my-connector', + chatService: mockChatService, + initialMessages: [], + } as UseChatProps, + }); - describe('after a completed response it updates the returned messages and the loading state', () => {}); + subject = new Subject(); - describe('after aborting a response it shows the partial message and sets chatState to aborted', () => {}); + mockChatService.chat.mockReturnValueOnce(subject); - describe('after a response errors out, it shows the partial message and sets chatState to error', () => {}); - }); + act(() => { + hookResult.result.current.next([ + ...hookResult.result.current.messages, + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.User, + content: 'hello', + }, + }, + ]); + }); + }); - // Tests for the 'next' function behavior - describe('Function next', () => { - it('should 
handle empty message array correctly', () => {}); - it('should ignore non-user and non-assistant messages with function requests', () => {}); - it('should set chat state to loading on valid next message', () => {}); - it('should handle abort signal correctly during message processing', () => {}); - it('should handle message processing for assistant messages with function request', () => {}); - it('should handle message processing for user messages', () => {}); - it('should handle observable responses correctly', () => {}); - it('should update messages correctly after response', () => {}); - it('should handle errors during message processing', () => {}); - }); + it('sets the chatState to loading', () => { + expect(hookResult.result.current.state).toBe(ChatState.Loading); + }); - // Tests for the 'stop' function behavior - describe('Function stop', () => { - it('should abort current operation when stop is called', () => {}); - }); + describe('after asking for another response', () => { + beforeEach(() => { + act(() => { + hookResult.result.current.next([]); + subject.next({ + message: { + role: MessageRole.User, + content: 'goodbye', + }, + }); + subject.complete(); + }); + }); + + it('shows an empty list of messages', () => { + expect(hookResult.result.current.messages.length).toBe(1); + expect(hookResult.result.current.messages[0].message.role).toBe(MessageRole.System); + }); + + it('aborts the running request', () => { + expect(subject.observed).toBe(false); + }); + }); + + describe('after a partial response', () => { + it('updates the returned messages', () => { + act(() => { + subject.next({ + message: { + content: 'good', + role: MessageRole.Assistant, + }, + }); + }); + + expect(hookResult.result.current.messages[2].message.content).toBe('good'); + }); + }); + + describe('after a completed response', () => { + it('updates the returned messages and the loading state', () => { + act(() => { + subject.next({ + message: { + content: 'good', + role: 
MessageRole.Assistant, + }, + }); + subject.next({ + message: { + content: 'goodbye', + role: MessageRole.Assistant, + }, + }); + subject.complete(); + }); + + expect(hookResult.result.current.messages[2].message.content).toBe('goodbye'); + expect(hookResult.result.current.state).toBe(ChatState.Ready); + }); + }); + + describe('after aborting a response', () => { + beforeEach(() => { + act(() => { + subject.next({ + message: { + content: 'good', + role: MessageRole.Assistant, + }, + aborted: true, + }); + subject.complete(); + }); + }); + + it('shows the partial message and sets chatState to aborted', () => { + expect(hookResult.result.current.messages[2].message.content).toBe('good'); + expect(hookResult.result.current.state).toBe(ChatState.Aborted); + }); + + it('does not show an error toast', () => { + expect(addErrorMock).not.toHaveBeenCalled(); + }); + }); + + describe('after a response errors out', () => { + beforeEach(() => { + act(() => { + subject.next({ + message: { + content: 'good', + role: MessageRole.Assistant, + }, + error: new Error('foo'), + }); + subject.complete(); + }); + }); + + it('shows the partial message and sets chatState to error', () => { + expect(hookResult.result.current.messages[2].message.content).toBe('good'); + expect(hookResult.result.current.state).toBe(ChatState.Error); + }); + + it('shows an error toast', () => { + expect(addErrorMock).toHaveBeenCalled(); + }); + }); + + describe('after the LLM responds with a function call', () => { + let resolve: (data: any) => void; + let reject: (error: Error) => void; + + beforeEach(() => { + mockChatService.executeFunction.mockResolvedValueOnce( + new Promise((...args) => { + resolve = args[0]; + reject = args[1]; + }) + ); + + act(() => { + subject.next({ + message: { + content: '', + role: MessageRole.Assistant, + function_call: { + name: 'my_function', + arguments: JSON.stringify({ foo: 'bar' }), + trigger: MessageRole.Assistant, + }, + }, + }); + subject.complete(); + }); + }); + + 
it('the chat state stays loading', () => { + expect(hookResult.result.current.state).toBe(ChatState.Loading); + }); + + it('adds a message', () => { + const { messages } = hookResult.result.current; + + expect(messages.length).toBe(3); + expect(messages[2]).toEqual({ + '@timestamp': expect.any(String), + message: { + content: '', + function_call: { + arguments: JSON.stringify({ foo: 'bar' }), + name: 'my_function', + trigger: MessageRole.Assistant, + }, + role: MessageRole.Assistant, + }, + }); + }); - // Tests for the state management within the hook - describe('State management', () => { - it('should update chat state correctly', () => {}); - it('should update messages state correctly', () => {}); - it('should handle pending message state correctly', () => {}); + describe('the function call succeeds', () => { + beforeEach(async () => { + subject = new Subject(); + mockChatService.chat.mockReturnValueOnce(subject); + + await act(async () => { + resolve({ content: { foo: 'bar' }, data: { bar: 'foo' } }); + }); + }); + + it('adds a message', () => { + const { messages } = hookResult.result.current; + + expect(messages.length).toBe(4); + expect(messages[3]).toEqual({ + '@timestamp': expect.any(String), + message: { + content: JSON.stringify({ foo: 'bar' }), + data: JSON.stringify({ bar: 'foo' }), + name: 'my_function', + role: MessageRole.User, + }, + }); + }); + + it('keeps the chat state in loading', () => { + expect(hookResult.result.current.state).toBe(ChatState.Loading); + }); + it('sends the function call back to the LLM for a response', () => { + expect(mockChatService.chat).toHaveBeenCalledTimes(2); + expect(mockChatService.chat).toHaveBeenLastCalledWith({ + connectorId: 'my-connector', + messages: hookResult.result.current.messages, + }); + }); + }); + + describe('the function call fails', () => { + beforeEach(async () => { + subject = new Subject(); + mockChatService.chat.mockReturnValue(subject); + + await act(async () => { + reject(new Error('connection 
error')); + }); + }); + + it('keeps the chat state in loading', () => { + expect(hookResult.result.current.state).toBe(ChatState.Loading); + }); + + it('adds a message', () => { + const { messages } = hookResult.result.current; + + expect(messages.length).toBe(4); + expect(messages[3]).toEqual({ + '@timestamp': expect.any(String), + message: { + content: JSON.stringify({ + message: 'Error: connection error', + error: {}, + }), + name: 'my_function', + role: MessageRole.User, + }, + }); + }); + + it('does not show an error toast', () => { + expect(addErrorMock).not.toHaveBeenCalled(); + }); + + it('sends the function call back to the LLM for a response', () => { + expect(mockChatService.chat).toHaveBeenCalledTimes(2); + expect(mockChatService.chat).toHaveBeenLastCalledWith({ + connectorId: 'my-connector', + messages: hookResult.result.current.messages, + }); + }); + }); + + describe('stop() is called', () => { + beforeEach(() => { + act(() => { + hookResult.result.current.stop(); + }); + }); + + it('sets the chatState to aborted', () => { + expect(hookResult.result.current.state).toBe(ChatState.Aborted); + }); + + it('has called the abort controller', () => { + const signal = mockChatService.executeFunction.mock.calls[0][0].signal; + + expect(signal.aborted).toBe(true); + }); + + it('is not updated after the promise is rejected', () => { + const numRenders = hookResult.result.all.length; + + act(() => { + reject(new Error('Request aborted')); + }); + + expect(numRenders).toBe(hookResult.result.all.length); + }); + + it('removes all subscribers', () => { + expect(subject.observed).toBe(false); + }); + }); + + describe('setMessages() is called', () => {}); + }); }); - // Tests for cleanup and unmounting behavior - describe('Cleanup and unmounting behavior', () => { - it('should abort any ongoing process on unmount', () => {}); + describe('when calling next() with the recall function available', () => { + let subject: Subject; + + beforeEach(async () => { + hookResult 
= renderHook(useChat, { + initialProps: { + connectorId: 'my-connector', + chatService: mockChatService, + initialMessages: [], + } as UseChatProps, + }); + + subject = new Subject(); + + mockChatService.hasFunction.mockReturnValue(true); + mockChatService.executeFunction.mockResolvedValueOnce({ + content: [ + { + id: 'my_document', + text: 'My text', + }, + ], + }); + + mockChatService.chat.mockReturnValueOnce(subject); + + await act(async () => { + hookResult.result.current.next([ + ...hookResult.result.current.messages, + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.User, + content: 'hello', + }, + }, + ]); + }); + }); + + it('adds a user message and a recall function request', () => { + expect(hookResult.result.current.messages[1].message.content).toBe('hello'); + expect(hookResult.result.current.messages[2].message.function_call?.name).toBe('recall'); + expect(hookResult.result.current.messages[2].message.content).toBe(''); + expect(hookResult.result.current.messages[2].message.function_call?.arguments).toBe( + JSON.stringify({ queries: [], contexts: [] }) + ); + expect(hookResult.result.current.messages[3].message.name).toBe('recall'); + expect(hookResult.result.current.messages[3].message.content).toBe( + JSON.stringify([ + { + id: 'my_document', + text: 'My text', + }, + ]) + ); + }); + + it('executes the recall function', () => { + expect(mockChatService.executeFunction).toHaveBeenCalled(); + expect(mockChatService.executeFunction).toHaveBeenCalledWith({ + signal: expect.any(AbortSignal), + connectorId: 'my-connector', + args: JSON.stringify({ queries: [], contexts: [] }), + name: 'recall', + messages: [...hookResult.result.current.messages.slice(0, -1)], + }); + }); + + it('sends the user message, function request and recall response to the LLM', () => { + expect(mockChatService.chat).toHaveBeenCalledWith({ + connectorId: 'my-connector', + messages: [...hookResult.result.current.messages], + }); + }); }); }); diff --git 
a/x-pack/plugins/observability_ai_assistant/public/hooks/use_chat.ts b/x-pack/plugins/observability_ai_assistant/public/hooks/use_chat.ts index 4805824668344f..aeef36127f6c4b 100644 --- a/x-pack/plugins/observability_ai_assistant/public/hooks/use_chat.ts +++ b/x-pack/plugins/observability_ai_assistant/public/hooks/use_chat.ts @@ -13,6 +13,7 @@ import { type Message, MessageRole } from '../../common'; import { getAssistantSetupMessage } from '../service/get_assistant_setup_message'; import type { ObservabilityAIAssistantChatService, PendingMessage } from '../types'; import { useKibana } from './use_kibana'; +import { useOnce } from './use_once'; export enum ChatState { Ready = 'ready', @@ -21,24 +22,35 @@ export enum ChatState { Aborted = 'aborted', } -interface UseChatResult { +export interface UseChatResult { messages: Message[]; + setMessages: (messages: Message[]) => void; state: ChatState; next: (messages: Message[]) => void; stop: () => void; } +export interface UseChatProps { + initialMessages: Message[]; + chatService: ObservabilityAIAssistantChatService; + connectorId?: string; + onChatComplete?: (messages: Message[]) => void; +} + export function useChat({ initialMessages, chatService, connectorId, -}: { - initialMessages: Message[]; - chatService: ObservabilityAIAssistantChatService; - connectorId: string; -}): UseChatResult { + onChatComplete, +}: UseChatProps): UseChatResult { const [chatState, setChatState] = useState(ChatState.Ready); + const systemMessage = useMemo(() => { + return getAssistantSetupMessage({ contexts: chatService.getContexts() }); + }, [chatService]); + + useOnce(initialMessages); + const [messages, setMessages] = useState(initialMessages); const [pendingMessage, setPendingMessage] = useState(); @@ -49,116 +61,165 @@ export function useChat({ services: { notifications }, } = useKibana(); + const onChatCompleteRef = useRef(onChatComplete); + + onChatCompleteRef.current = onChatComplete; + const handleSignalAbort = useCallback(() => { 
setChatState(ChatState.Aborted); }, []); - async function next(nextMessages: Message[]) { - abortControllerRef.current.signal.removeEventListener('abort', handleSignalAbort); + const next = useCallback( + async (nextMessages: Message[]) => { + // make sure we ignore any aborts for the previous signal + abortControllerRef.current.signal.removeEventListener('abort', handleSignalAbort); - const lastMessage = last(nextMessages); + // cancel running requests + abortControllerRef.current.abort(); - if (!lastMessage) { - setChatState(ChatState.Ready); - return; - } + const lastMessage = last(nextMessages); - const isUserMessage = lastMessage.message.role === MessageRole.User; - const functionCall = lastMessage.message.function_call; - const isAssistantMessageWithFunctionRequest = - lastMessage.message.role === MessageRole.Assistant && functionCall && !!functionCall?.name; + const allMessages = [ + systemMessage, + ...nextMessages.filter((message) => message.message.role !== MessageRole.System), + ]; - if (!isUserMessage && !isAssistantMessageWithFunctionRequest) { - setChatState(ChatState.Ready); - return; - } + setMessages(allMessages); - const abortController = (abortControllerRef.current = new AbortController()); + if (!lastMessage || !connectorId) { + setChatState(ChatState.Ready); + onChatCompleteRef.current?.(nextMessages); + return; + } - abortController.signal.addEventListener('abort', handleSignalAbort); + const isUserMessage = lastMessage.message.role === MessageRole.User; + const functionCall = lastMessage.message.function_call; + const isAssistantMessageWithFunctionRequest = + lastMessage.message.role === MessageRole.Assistant && functionCall && !!functionCall.name; - setChatState(ChatState.Loading); + const isFunctionResult = isUserMessage && !!lastMessage.message.name; - const allMessages = [ - getAssistantSetupMessage({ contexts: chatService.getContexts() }), - ...nextMessages.filter((message) => message.message.role !== MessageRole.System), - ]; + const 
isRecallFunctionAvailable = chatService.hasFunction('recall'); - function handleError(error: Error) { - notifications.toasts.addError(error, { - title: i18n.translate('xpack.observabilityAiAssistant.failedToLoadResponse', { - defaultMessage: 'Failed to load response from the AI Assistant', - }), - }); - } - - const response = isAssistantMessageWithFunctionRequest - ? await chatService - .executeFunction({ - name: functionCall.name, - signal: abortController.signal, - args: functionCall.arguments, - connectorId, - messages: allMessages, - }) - .catch((error) => { - return { - content: JSON.stringify({ - message: error.toString(), - error, - }), - data: undefined, - }; - }) - : chatService.chat({ - messages: allMessages, - connectorId, - }); + if (!isUserMessage && !isAssistantMessageWithFunctionRequest) { + setChatState(ChatState.Ready); + onChatCompleteRef.current?.(nextMessages); + return; + } + + const abortController = (abortControllerRef.current = new AbortController()); + + abortController.signal.addEventListener('abort', handleSignalAbort); + + setChatState(ChatState.Loading); - if (isObservable(response)) { - const localPendingMessage = pendingMessage!; - const subscription = response.subscribe({ - next: (nextPendingMessage) => { - setPendingMessage(nextPendingMessage); - }, - complete: () => { - setPendingMessage(undefined); - const allMessagesWithResolved = allMessages.concat({ - message: { - ...localPendingMessage.message, + if (isUserMessage && !isFunctionResult && isRecallFunctionAvailable) { + const allMessagesWithRecall = allMessages.concat({ + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.Assistant, + content: '', + function_call: { + name: 'recall', + arguments: JSON.stringify({ queries: [], contexts: [] }), + trigger: MessageRole.Assistant, }, - '@timestamp': new Date().toISOString(), + }, + }); + next(allMessagesWithRecall); + return; + } + + function handleError(error: Error) { + setChatState(ChatState.Error); + 
notifications.toasts.addError(error, { + title: i18n.translate('xpack.observabilityAiAssistant.failedToLoadResponse', { + defaultMessage: 'Failed to load response from the AI Assistant', + }), + }); + } + + const response = isAssistantMessageWithFunctionRequest + ? await chatService + .executeFunction({ + name: functionCall.name, + signal: abortController.signal, + args: functionCall.arguments, + connectorId, + messages: allMessages, + }) + .catch((error) => { + return { + content: { + message: error.toString(), + error, + }, + data: undefined, + }; + }) + : chatService.chat({ + messages: allMessages, + connectorId, }); - setMessages(allMessagesWithResolved); - if (localPendingMessage.aborted) { - setChatState(ChatState.Aborted); - } else if (localPendingMessage.error) { - handleError(localPendingMessage.error); - } else { - next(allMessagesWithResolved); - } - }, - error: (error) => { - handleError(error); - }, - }); - - abortController.signal.addEventListener('abort', () => { - subscription.unsubscribe(); - }); - } else { - const allMessagesWithFunctionReply = allMessages.concat({ - '@timestamp': new Date().toISOString(), - message: { - name: functionCall!.name, - role: MessageRole.User, - content: JSON.stringify(response.content), - data: JSON.stringify(response.data), - }, - }); - next(allMessagesWithFunctionReply); - } - } + + if (abortController.signal.aborted) { + return; + } + + if (isObservable(response)) { + let localPendingMessage: PendingMessage = { + message: { + content: '', + role: MessageRole.User, + }, + }; + + const subscription = response.subscribe({ + next: (nextPendingMessage) => { + localPendingMessage = nextPendingMessage; + setPendingMessage(nextPendingMessage); + }, + complete: () => { + setPendingMessage(undefined); + const allMessagesWithResolved = allMessages.concat({ + message: { + ...localPendingMessage.message, + }, + '@timestamp': new Date().toISOString(), + }); + if (localPendingMessage.aborted) { + setChatState(ChatState.Aborted); 
+ setMessages(allMessagesWithResolved); + } else if (localPendingMessage.error) { + handleError(localPendingMessage.error); + setMessages(allMessagesWithResolved); + } else { + next(allMessagesWithResolved); + } + }, + error: (error) => { + handleError(error); + }, + }); + + abortController.signal.addEventListener('abort', () => { + subscription.unsubscribe(); + }); + } else { + const allMessagesWithFunctionReply = allMessages.concat({ + '@timestamp': new Date().toISOString(), + message: { + name: functionCall!.name, + role: MessageRole.User, + content: JSON.stringify(response.content), + data: JSON.stringify(response.data), + }, + }); + next(allMessagesWithFunctionReply); + } + }, + [connectorId, chatService, handleSignalAbort, notifications.toasts, systemMessage] + ); useEffect(() => { return () => { @@ -167,13 +228,29 @@ export function useChat({ }, []); const memoizedMessages = useMemo(() => { + const includingSystemMessage = [ + systemMessage, + ...messages.filter((message) => message.message.role !== MessageRole.System), + ]; + return pendingMessage - ? messages.concat({ ...pendingMessage, '@timestamp': new Date().toISOString() }) - : messages; - }, [messages, pendingMessage]); + ? 
includingSystemMessage.concat({ + ...pendingMessage, + '@timestamp': new Date().toISOString(), + }) + : includingSystemMessage; + }, [systemMessage, messages, pendingMessage]); + + const setMessagesWithAbort = useCallback((nextMessages: Message[]) => { + abortControllerRef.current.abort(); + setPendingMessage(undefined); + setChatState(ChatState.Ready); + setMessages(nextMessages); + }, []); return { messages: memoizedMessages, + setMessages: setMessagesWithAbort, state: chatState, next, stop: () => { diff --git a/x-pack/plugins/observability_ai_assistant/public/hooks/use_conversation.test.tsx b/x-pack/plugins/observability_ai_assistant/public/hooks/use_conversation.test.tsx new file mode 100644 index 00000000000000..b9e837962ee1ba --- /dev/null +++ b/x-pack/plugins/observability_ai_assistant/public/hooks/use_conversation.test.tsx @@ -0,0 +1,536 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ +import React from 'react'; +import { + useConversation, + type UseConversationProps, + type UseConversationResult, +} from './use_conversation'; +import { + act, + renderHook, + type RenderHookResult, + type WrapperComponent, +} from '@testing-library/react-hooks'; +import type { ObservabilityAIAssistantService, PendingMessage } from '../types'; +import type { DeeplyMockedKeys } from '@kbn/utility-types-jest'; +import { ObservabilityAIAssistantProvider } from '../context/observability_ai_assistant_provider'; +import * as useKibanaModule from './use_kibana'; +import { Message, MessageRole } from '../../common'; +import { ChatState } from './use_chat'; +import { createMockChatService } from '../service/create_mock_chat_service'; +import { Subject } from 'rxjs'; +import { EMPTY_CONVERSATION_TITLE } from '../i18n'; +import { omit } from 'lodash'; + +let hookResult: RenderHookResult; + +type MockedService = DeeplyMockedKeys; + +const mockService: MockedService = { + callApi: jest.fn(), + getCurrentUser: jest.fn(), + getLicense: jest.fn(), + getLicenseManagementLocator: jest.fn(), + isEnabled: jest.fn(), + start: jest.fn(), +}; + +const mockChatService = createMockChatService(); + +const addErrorMock = jest.fn(); + +jest.spyOn(useKibanaModule, 'useKibana').mockReturnValue({ + services: { + notifications: { + toasts: { + addError: addErrorMock, + }, + }, + }, +} as any); + +describe('useConversation', () => { + let wrapper: WrapperComponent; + + beforeEach(() => { + jest.clearAllMocks(); + wrapper = ({ children }) => ( + + {children} + + ); + }); + + describe('with initial messages and a conversation id', () => { + beforeEach(() => { + hookResult = renderHook(useConversation, { + initialProps: { + chatService: mockChatService, + connectorId: 'my-connector', + initialMessages: [ + { + '@timestamp': new Date().toISOString(), + message: { content: '', role: MessageRole.User }, + }, + ], + initialConversationId: 'foo', + }, + wrapper, + }); + }); + it('throws an error', 
() => { + expect(hookResult.result.error).toBeTruthy(); + }); + }); + + describe('without initial messages and a conversation id', () => { + beforeEach(() => { + hookResult = renderHook(useConversation, { + initialProps: { + chatService: mockChatService, + connectorId: 'my-connector', + }, + wrapper, + }); + }); + + it('returns only the system message', () => { + expect(hookResult.result.current.messages).toEqual([ + { + '@timestamp': expect.any(String), + message: { + content: '', + role: MessageRole.System, + }, + }, + ]); + }); + + it('returns a ready state', () => { + expect(hookResult.result.current.state).toBe(ChatState.Ready); + }); + + it('does not call the fetch api', () => { + expect(mockService.callApi).not.toHaveBeenCalled(); + }); + }); + + describe('with initial messages', () => { + beforeEach(() => { + hookResult = renderHook(useConversation, { + initialProps: { + chatService: mockChatService, + connectorId: 'my-connector', + initialMessages: [ + { + '@timestamp': new Date().toISOString(), + message: { + content: 'Test', + role: MessageRole.User, + }, + }, + ], + }, + wrapper, + }); + }); + + it('returns the system message and the initial messages', () => { + expect(hookResult.result.current.messages).toEqual([ + { + '@timestamp': expect.any(String), + message: { + content: '', + role: MessageRole.System, + }, + }, + { + '@timestamp': expect.any(String), + message: { + content: 'Test', + role: MessageRole.User, + }, + }, + ]); + }); + }); + + describe('with a conversation id that successfully loads', () => { + beforeEach(async () => { + mockService.callApi.mockResolvedValueOnce({ + conversation: { + id: 'my-conversation-id', + }, + messages: [ + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.System, + content: 'System', + }, + }, + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.User, + content: 'User', + }, + }, + ], + }); + + hookResult = renderHook(useConversation, { + initialProps: { + 
chatService: mockChatService, + connectorId: 'my-connector', + initialConversationId: 'my-conversation-id', + }, + wrapper, + }); + + await act(async () => {}); + }); + + it('returns the loaded conversation', () => { + expect(hookResult.result.current.conversation.value).toEqual({ + conversation: { + id: 'my-conversation-id', + }, + messages: [ + { + '@timestamp': expect.any(String), + message: { + content: 'System', + role: MessageRole.System, + }, + }, + { + '@timestamp': expect.any(String), + message: { + content: 'User', + role: MessageRole.User, + }, + }, + ], + }); + }); + + it('sets messages to the messages of the conversation', () => { + expect(hookResult.result.current.messages).toEqual([ + { + '@timestamp': expect.any(String), + message: { + content: expect.any(String), + role: MessageRole.System, + }, + }, + { + '@timestamp': expect.any(String), + message: { + content: 'User', + role: MessageRole.User, + }, + }, + ]); + }); + + it('overrides the system message', () => { + expect(hookResult.result.current.messages[0].message.content).toBe(''); + }); + }); + + describe('with a conversation id that fails to load', () => { + beforeEach(async () => { + mockService.callApi.mockRejectedValueOnce(new Error('failed to load')); + + hookResult = renderHook(useConversation, { + initialProps: { + chatService: mockChatService, + connectorId: 'my-connector', + initialConversationId: 'my-conversation-id', + }, + wrapper, + }); + + await act(async () => {}); + }); + + it('returns an error', () => { + expect(hookResult.result.current.conversation.error).toBeTruthy(); + }); + + it('resets the messages', () => { + expect(hookResult.result.current.messages.length).toBe(1); + }); + }); + + describe('when chat completes without an initial conversation id', () => { + const subject: Subject = new Subject(); + beforeEach(() => { + mockService.callApi.mockImplementation(async (endpoint, request) => ({ + conversation: { + id: 'my-conversation-id', + }, + messages: (request as 
any).params.body.messages, + })); + + hookResult = renderHook(useConversation, { + initialProps: { + chatService: mockChatService, + connectorId: 'my-connector', + initialMessages: [ + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.User, + content: 'Hello', + }, + }, + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.Assistant, + content: 'Goodbye', + }, + }, + ], + }, + wrapper, + }); + + mockChatService.chat.mockImplementationOnce(() => { + return subject; + }); + }); + + it('the conversation is created including the initial messages', async () => { + const expectedMessages = [ + { + '@timestamp': expect.any(String), + message: { + role: MessageRole.System, + content: '', + }, + }, + { + '@timestamp': expect.any(String), + message: { + role: MessageRole.User, + content: 'Hello', + }, + }, + { + '@timestamp': expect.any(String), + message: { + role: MessageRole.Assistant, + content: 'Goodbye', + }, + }, + { + '@timestamp': expect.any(String), + message: { + role: MessageRole.User, + content: 'Hello again', + }, + }, + { + '@timestamp': expect.any(String), + message: { + role: MessageRole.Assistant, + content: 'Goodbye again', + }, + }, + ]; + + act(() => { + hookResult.result.current.next( + hookResult.result.current.messages.concat({ + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.User, + content: 'Hello again', + }, + }) + ); + subject.next({ + message: { + role: MessageRole.Assistant, + content: 'Goodbye again', + }, + }); + subject.complete(); + }); + + await act(async () => {}); + + expect(mockService.callApi.mock.calls[0]).toEqual([ + 'POST /internal/observability_ai_assistant/conversation', + { + params: { + body: { + conversation: { + '@timestamp': expect.any(String), + conversation: { + title: EMPTY_CONVERSATION_TITLE, + }, + messages: expectedMessages, + labels: {}, + numeric_labels: {}, + public: false, + }, + }, + }, + signal: null, + }, + ]); + + 
expect(hookResult.result.current.messages).toEqual(expectedMessages); + }); + }); + + describe('when chat completes with an initial conversation id', () => { + let subject: Subject; + + const initialMessages: Message[] = [ + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.System, + content: '', + }, + }, + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.User, + content: 'user', + }, + }, + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.Assistant, + content: 'assistant', + }, + }, + ]; + + beforeEach(async () => { + mockService.callApi.mockImplementation(async (endpoint, request) => ({ + '@timestamp': new Date().toISOString(), + conversation: { + id: 'my-conversation-id', + title: EMPTY_CONVERSATION_TITLE, + }, + labels: {}, + numeric_labels: {}, + public: false, + messages: initialMessages, + })); + + hookResult = renderHook(useConversation, { + initialProps: { + chatService: mockChatService, + connectorId: 'my-connector', + initialConversationId: 'my-conversation-id', + }, + wrapper, + }); + + await act(async () => {}); + }); + + it('the conversation is loadeded', async () => { + expect(mockService.callApi.mock.calls[0]).toEqual([ + 'GET /internal/observability_ai_assistant/conversation/{conversationId}', + { + signal: expect.anything(), + params: { + path: { + conversationId: 'my-conversation-id', + }, + }, + }, + ]); + + expect(hookResult.result.current.messages).toEqual( + initialMessages.map((msg) => ({ ...msg, '@timestamp': expect.any(String) })) + ); + }); + + describe('after chat completes', () => { + const nextUserMessage: Message = { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.User, + content: 'Hello again', + }, + }; + + const nextAssistantMessage: Message = { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.Assistant, + content: 'Goodbye again', + }, + }; + + beforeEach(async () => { + 
mockService.callApi.mockClear(); + subject = new Subject(); + + mockChatService.chat.mockImplementationOnce(() => { + return subject; + }); + + act(() => { + hookResult.result.current.next( + hookResult.result.current.messages.concat(nextUserMessage) + ); + subject.next(omit(nextAssistantMessage, '@timestamp')); + subject.complete(); + }); + + await act(async () => {}); + }); + + it('saves the updated message', () => { + expect(mockService.callApi.mock.calls[0]).toEqual([ + 'PUT /internal/observability_ai_assistant/conversation/{conversationId}', + { + params: { + path: { + conversationId: 'my-conversation-id', + }, + body: { + conversation: { + '@timestamp': expect.any(String), + conversation: { + title: EMPTY_CONVERSATION_TITLE, + id: 'my-conversation-id', + }, + messages: initialMessages + .concat([nextUserMessage, nextAssistantMessage]) + .map((msg) => ({ ...msg, '@timestamp': expect.any(String) })), + labels: {}, + numeric_labels: {}, + public: false, + }, + }, + }, + signal: null, + }, + ]); + }); + }); + }); + + describe('when the title is updated', () => { + it('the conversation is saved with the updated title', () => {}); + }); +}); diff --git a/x-pack/plugins/observability_ai_assistant/public/hooks/use_conversation.ts b/x-pack/plugins/observability_ai_assistant/public/hooks/use_conversation.ts index 6970c53e28bf19..8f8581874c27ba 100644 --- a/x-pack/plugins/observability_ai_assistant/public/hooks/use_conversation.ts +++ b/x-pack/plugins/observability_ai_assistant/public/hooks/use_conversation.ts @@ -6,145 +6,127 @@ */ import { i18n } from '@kbn/i18n'; import { merge, omit } from 'lodash'; -import { Dispatch, SetStateAction, useCallback, useMemo, useState } from 'react'; -import { type Conversation, type Message } from '../../common'; -import { ConversationCreateRequest, MessageRole } from '../../common/types'; -import { getAssistantSetupMessage } from '../service/get_assistant_setup_message'; -import { ObservabilityAIAssistantChatService } from 
'../types'; +import { useState } from 'react'; +import type { Conversation, Message } from '../../common'; +import type { ConversationCreateRequest } from '../../common/types'; +import { EMPTY_CONVERSATION_TITLE } from '../i18n'; +import type { ObservabilityAIAssistantChatService } from '../types'; import { useAbortableAsync, type AbortableAsyncState } from './use_abortable_async'; +import { useChat, UseChatResult } from './use_chat'; import { useKibana } from './use_kibana'; import { useObservabilityAIAssistant } from './use_observability_ai_assistant'; -import { createNewConversation } from './use_timeline'; +import { useOnce } from './use_once'; + +function createNewConversation({ + title = EMPTY_CONVERSATION_TITLE, +}: { title?: string } = {}): ConversationCreateRequest { + return { + '@timestamp': new Date().toISOString(), + messages: [], + conversation: { + title, + }, + labels: {}, + numeric_labels: {}, + public: false, + }; +} + +export interface UseConversationProps { + initialConversationId?: string; + initialMessages?: Message[]; + initialTitle?: string; + chatService: ObservabilityAIAssistantChatService; + connectorId: string | undefined; + onConversationUpdate?: (conversation: Conversation) => void; +} + +export type UseConversationResult = { + conversation: AbortableAsyncState; + saveTitle: (newTitle: string) => void; +} & Omit; + +const DEFAULT_INITIAL_MESSAGES: Message[] = []; export function useConversation({ - conversationId, + initialConversationId: initialConversationIdFromProps, + initialMessages: initialMessagesFromProps = DEFAULT_INITIAL_MESSAGES, + initialTitle: initialTitleFromProps, chatService, connectorId, - initialMessages = [], -}: { - conversationId?: string; - chatService?: ObservabilityAIAssistantChatService; // will eventually resolve to a non-nullish value - connectorId: string | undefined; - initialMessages?: Message[]; -}): { - conversation: AbortableAsyncState; - displayedMessages: Message[]; - setDisplayedMessages: Dispatch>; 
- getSystemMessage: () => Message; - save: (messages: Message[], handleRefreshConversations?: () => void) => Promise; - saveTitle: ( - title: string, - handleRefreshConversations?: () => void - ) => Promise; -} { + onConversationUpdate, +}: UseConversationProps): UseConversationResult { const service = useObservabilityAIAssistant(); const { services: { notifications }, } = useKibana(); - const [displayedMessages, setDisplayedMessages] = useState(initialMessages); - - const getSystemMessage = useCallback(() => { - return getAssistantSetupMessage({ contexts: chatService?.getContexts() || [] }); - }, [chatService]); + const initialConversationId = useOnce(initialConversationIdFromProps); + const initialMessages = useOnce(initialMessagesFromProps); + const initialTitle = useOnce(initialTitleFromProps); - const displayedMessagesWithHardcodedSystemMessage = useMemo(() => { - if (!chatService) { - return displayedMessages; - } + if (initialMessages.length && initialConversationId) { + throw new Error('Cannot set initialMessages if initialConversationId is set'); + } - const systemMessage = getSystemMessage(); - - if (displayedMessages[0]?.message.role === MessageRole.User) { - return [systemMessage, ...displayedMessages]; - } - - return [systemMessage, ...displayedMessages.slice(1)]; - }, [displayedMessages, chatService, getSystemMessage]); + const update = (nextConversationObject: Conversation) => { + return service + .callApi(`PUT /internal/observability_ai_assistant/conversation/{conversationId}`, { + signal: null, + params: { + path: { + conversationId: nextConversationObject.conversation.id, + }, + body: { + conversation: merge( + { + '@timestamp': nextConversationObject['@timestamp'], + conversation: { + id: nextConversationObject.conversation.id, + }, + }, + omit(nextConversationObject, 'conversation.last_updated', 'namespace', 'user') + ), + }, + }, + }) + .catch((err) => { + notifications.toasts.addError(err, { + title: 
i18n.translate('xpack.observabilityAiAssistant.errorUpdatingConversation', { + defaultMessage: 'Could not update conversation', + }), + }); + throw err; + }); + }; - const conversation: AbortableAsyncState = - useAbortableAsync( - ({ signal }) => { - if (!conversationId) { - const nextConversation = createNewConversation({ - contexts: chatService?.getContexts() || [], - }); - setDisplayedMessages(nextConversation.messages); - return nextConversation; - } + const save = (nextMessages: Message[]) => { + const conversationObject = conversation.value!; - return service - .callApi('GET /internal/observability_ai_assistant/conversation/{conversationId}', { - signal, - params: { path: { conversationId } }, - }) - .then((nextConversation) => { - setDisplayedMessages(nextConversation.messages); - return nextConversation; - }) - .catch((error) => { - setDisplayedMessages([]); - throw error; - }); - }, - [conversationId, chatService] - ); + const nextConversationObject = merge({}, omit(conversationObject, 'messages'), { + messages: nextMessages, + }); - return { - conversation, - displayedMessages: displayedMessagesWithHardcodedSystemMessage, - setDisplayedMessages, - getSystemMessage, - save: (messages: Message[], handleRefreshConversations?: () => void) => { - const conversationObject = conversation.value!; - - return conversationId - ? 
service - .callApi(`PUT /internal/observability_ai_assistant/conversation/{conversationId}`, { - signal: null, - params: { - path: { - conversationId, - }, - body: { - conversation: merge( - { - '@timestamp': conversationObject['@timestamp'], - conversation: { - id: conversationId, - }, - }, - omit( - conversationObject, - 'conversation.last_updated', - 'namespace', - 'user', - 'messages' - ), - { messages } - ), - }, - }, - }) - .catch((err) => { - notifications.toasts.addError(err, { - title: i18n.translate('xpack.observabilityAiAssistant.errorUpdatingConversation', { - defaultMessage: 'Could not update conversation', - }), - }); - throw err; - }) + return ( + initialConversationId + ? update( + merge( + { conversation: { id: initialConversationId } }, + nextConversationObject + ) as Conversation + ) : service .callApi(`POST /internal/observability_ai_assistant/conversation`, { signal: null, params: { body: { - conversation: merge({}, conversationObject, { messages }), + conversation: nextConversationObject, }, }, }) .then((nextConversation) => { + setDisplayedConversationId(nextConversation.conversation.id); if (connectorId) { service .callApi( @@ -162,7 +144,7 @@ export function useConversation({ } ) .then(() => { - handleRefreshConversations?.(); + onConversationUpdate?.(nextConversation); return conversation.refresh(); }); } @@ -175,27 +157,64 @@ export function useConversation({ }), }); throw err; - }); + }) + ).then((nextConversation) => { + onConversationUpdate?.(nextConversation); + return nextConversation; + }); + }; + + const { next, messages, setMessages, state, stop } = useChat({ + initialMessages, + chatService, + connectorId, + onChatComplete: (nextMessages) => { + save(nextMessages); }, - saveTitle: (title: string, handleRefreshConversations?: () => void) => { - if (conversationId) { + }); + + const [displayedConversationId, setDisplayedConversationId] = useState(initialConversationId); + + const conversation: AbortableAsyncState = + 
useAbortableAsync( + ({ signal }) => { + if (!displayedConversationId) { + const nextConversation = createNewConversation({ title: initialTitle }); + return nextConversation; + } + return service - .callApi('PUT /internal/observability_ai_assistant/conversation/{conversationId}/title', { - signal: null, - params: { - path: { - conversationId, - }, - body: { - title, - }, - }, + .callApi('GET /internal/observability_ai_assistant/conversation/{conversationId}', { + signal, + params: { path: { conversationId: displayedConversationId } }, + }) + .then((nextConversation) => { + setMessages(nextConversation.messages); + return nextConversation; }) - .then(() => { - handleRefreshConversations?.(); + .catch((error) => { + setMessages([]); + throw error; }); + }, + [initialConversationId, initialTitle], + { + defaultValue: () => { + if (!initialConversationId) { + const nextConversation = createNewConversation({ title: initialTitle }); + return nextConversation; + } + return undefined; + }, } - return Promise.resolve(); - }, + ); + + return { + conversation, + state, + next, + stop, + messages, + saveTitle: () => {}, }; } diff --git a/x-pack/plugins/observability_ai_assistant/public/hooks/use_knowledge_base.tsx b/x-pack/plugins/observability_ai_assistant/public/hooks/use_knowledge_base.tsx index d7c76e0ab4f85f..24c83d3fa8eb48 100644 --- a/x-pack/plugins/observability_ai_assistant/public/hooks/use_knowledge_base.tsx +++ b/x-pack/plugins/observability_ai_assistant/public/hooks/use_knowledge_base.tsx @@ -57,6 +57,7 @@ export function useKnowledgeBase(): UseKnowledgeBaseResult { text: i18n.translate('xpack.observabilityAiAssistant.knowledgeBaseReadyContentReload', { defaultMessage: 'A page reload is needed to be able to use it.', }), + toastLifeTimeMs: Number.MAX_VALUE, }); }) .catch((error) => { diff --git a/x-pack/plugins/observability_ai_assistant/public/hooks/use_once.ts b/x-pack/plugins/observability_ai_assistant/public/hooks/use_once.ts new file mode 100644 index 
00000000000000..00dab01456af07 --- /dev/null +++ b/x-pack/plugins/observability_ai_assistant/public/hooks/use_once.ts @@ -0,0 +1,21 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { useRef } from 'react'; + +export function useOnce(variable: T): T { + const ref = useRef(variable); + + if (ref.current !== variable) { + // eslint-disable-next-line no-console + console.trace( + `Variable changed from ${ref.current} to ${variable}, but only the initial value will be taken into account` + ); + } + + return ref.current; +} diff --git a/x-pack/plugins/observability_ai_assistant/public/hooks/use_timeline.test.ts b/x-pack/plugins/observability_ai_assistant/public/hooks/use_timeline.test.ts deleted file mode 100644 index 6ad1d0746a517e..00000000000000 --- a/x-pack/plugins/observability_ai_assistant/public/hooks/use_timeline.test.ts +++ /dev/null @@ -1,611 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ -import type { FindActionResult } from '@kbn/actions-plugin/server'; -import { AbortError } from '@kbn/kibana-utils-plugin/common'; -import { - act, - renderHook, - type Renderer, - type RenderHookResult, -} from '@testing-library/react-hooks'; -import { BehaviorSubject, Subject } from 'rxjs'; -import { MessageRole } from '../../common'; -import { ChatTimelineItem } from '../components/chat/chat_timeline'; -import type { PendingMessage } from '../types'; -import { useTimeline, UseTimelineResult } from './use_timeline'; - -type HookProps = Parameters[0]; - -const WAIT_OPTIONS = { timeout: 1500 }; - -jest.mock('./use_kibana', () => ({ - useKibana: () => ({ - services: { - notifications: { - toasts: { - addError: jest.fn(), - }, - }, - }, - }), -})); - -describe('useTimeline', () => { - let hookResult: RenderHookResult>; - - describe('with an empty conversation', () => { - beforeAll(() => { - hookResult = renderHook((props) => useTimeline(props), { - initialProps: { - connectors: { - loading: false, - selectedConnector: 'OpenAI', - selectConnector: () => {}, - connectors: [{ id: 'OpenAI' }] as FindActionResult[], - }, - chatService: {}, - messages: [], - onChatComplete: jest.fn(), - onChatUpdate: jest.fn(), - } as unknown as HookProps, - }); - }); - it('renders the correct timeline items', () => { - expect(hookResult.result.current.items.length).toEqual(1); - - expect(hookResult.result.current.items[0]).toEqual({ - display: { - collapsed: false, - hide: false, - }, - actions: { - canCopy: false, - canEdit: false, - canRegenerate: false, - canGiveFeedback: false, - }, - role: MessageRole.User, - title: 'started a conversation', - loading: false, - id: expect.any(String), - }); - }); - }); - - describe('with an existing conversation', () => { - beforeAll(() => { - hookResult = renderHook((props) => useTimeline(props), { - initialProps: { - messages: [ - { - message: { - role: MessageRole.System, - content: 'You are a helpful assistant for Elastic Observability', - 
}, - }, - { - message: { - role: MessageRole.User, - content: 'hello', - }, - }, - { - message: { - role: MessageRole.Assistant, - content: '', - function_call: { - name: 'recall', - trigger: MessageRole.User, - }, - }, - }, - { - message: { - name: 'recall', - role: MessageRole.User, - content: '', - }, - }, - { - message: { - content: 'goodbye', - function_call: { - name: '', - arguments: '', - trigger: MessageRole.Assistant, - }, - role: MessageRole.Assistant, - }, - }, - ], - connectors: { - selectedConnector: 'foo', - }, - chatService: { - chat: () => {}, - hasRenderFunction: () => {}, - hasFunction: () => {}, - }, - } as unknown as HookProps, - }); - }); - it('renders the correct timeline items', () => { - expect(hookResult.result.current.items.length).toEqual(4); - - expect(hookResult.result.current.items[1]).toEqual({ - actions: { canCopy: true, canEdit: true, canGiveFeedback: false, canRegenerate: false }, - content: 'hello', - currentUser: undefined, - display: { collapsed: false, hide: false }, - element: undefined, - function_call: undefined, - id: expect.any(String), - loading: false, - role: MessageRole.User, - title: '', - }); - - expect(hookResult.result.current.items[3]).toEqual({ - actions: { canCopy: true, canEdit: false, canGiveFeedback: false, canRegenerate: true }, - content: 'goodbye', - currentUser: undefined, - display: { collapsed: false, hide: false }, - element: undefined, - function_call: { - arguments: '', - name: '', - trigger: MessageRole.Assistant, - }, - id: expect.any(String), - loading: false, - role: MessageRole.Assistant, - title: '', - }); - - // Items that are function calls are collapsed into an array. - - // 'title' is a component. This throws Jest for a loop. 
- const collapsedItemsWithoutTitle = ( - hookResult.result.current.items[2] as ChatTimelineItem[] - ).map(({ title, ...rest }) => rest); - - expect(collapsedItemsWithoutTitle).toEqual([ - { - display: { - collapsed: true, - hide: false, - }, - actions: { - canCopy: true, - canEdit: true, - canRegenerate: false, - canGiveFeedback: false, - }, - currentUser: undefined, - function_call: { - name: 'recall', - trigger: MessageRole.User, - }, - role: MessageRole.User, - content: `\`\`\` -{ - \"name\": \"recall\" -} -\`\`\``, - loading: false, - id: expect.any(String), - }, - { - display: { - collapsed: true, - hide: false, - }, - actions: { - canCopy: true, - canEdit: false, - canRegenerate: false, - canGiveFeedback: false, - }, - currentUser: undefined, - function_call: undefined, - role: MessageRole.User, - content: `\`\`\` -{} -\`\`\``, - loading: false, - id: expect.any(String), - }, - ]); - }); - }); - - describe('when submitting a new prompt', () => { - let subject: Subject; - - let props: Omit & { - onChatUpdate: jest.MockedFn; - onChatComplete: jest.MockedFn; - chatService: Omit & { - executeFunction: jest.MockedFn; - }; - }; - - beforeEach(() => { - props = { - messages: [], - connectors: { - selectedConnector: 'foo', - }, - chatService: { - chat: jest.fn().mockImplementation(() => { - subject = new BehaviorSubject({ - message: { - role: MessageRole.Assistant, - content: '', - }, - }); - return subject; - }), - executeFunction: jest.fn(), - hasFunction: jest.fn(), - hasRenderFunction: jest.fn(), - }, - onChatUpdate: jest.fn().mockImplementation((messages) => { - props = { ...props, messages }; - hookResult.rerender(props as unknown as HookProps); - }), - onChatComplete: jest.fn(), - } as any; - - hookResult = renderHook((nextProps) => useTimeline(nextProps), { - initialProps: props as unknown as HookProps, - }); - }); - - describe("and it's loading", () => { - beforeEach(() => { - act(() => { - hookResult.result.current.onSubmit({ - '@timestamp': new 
Date().toISOString(), - message: { role: MessageRole.User, content: 'Hello' }, - }); - }); - }); - - it('adds two items of which the last one is loading', async () => { - expect((hookResult.result.current.items[0] as ChatTimelineItem).role).toEqual( - MessageRole.User - ); - expect((hookResult.result.current.items[1] as ChatTimelineItem).role).toEqual( - MessageRole.User - ); - - expect((hookResult.result.current.items[2] as ChatTimelineItem).role).toEqual( - MessageRole.Assistant - ); - - expect(hookResult.result.current.items[1]).toMatchObject({ - role: MessageRole.User, - content: 'Hello', - loading: false, - }); - - expect(hookResult.result.current.items[2]).toMatchObject({ - role: MessageRole.Assistant, - content: '', - loading: true, - actions: { - canRegenerate: false, - canGiveFeedback: false, - }, - }); - - expect(hookResult.result.current.items.length).toBe(3); - - expect(hookResult.result.current.items[2]).toMatchObject({ - role: MessageRole.Assistant, - content: '', - loading: true, - actions: { - canRegenerate: false, - canGiveFeedback: false, - }, - }); - }); - - describe('and it pushes the next part', () => { - beforeEach(() => { - act(() => { - subject.next({ message: { role: MessageRole.Assistant, content: 'Goodbye' } }); - }); - }); - - it('adds the partial response', () => { - expect(hookResult.result.current.items[2]).toMatchObject({ - role: MessageRole.Assistant, - content: 'Goodbye', - loading: true, - actions: { - canRegenerate: false, - canGiveFeedback: false, - }, - }); - }); - - describe('and it completes', () => { - beforeEach(async () => { - act(() => { - subject.complete(); - }); - - await hookResult.waitForNextUpdate(WAIT_OPTIONS); - }); - - it('adds the completed message', () => { - expect(hookResult.result.current.items[2]).toMatchObject({ - role: MessageRole.Assistant, - content: 'Goodbye', - loading: false, - actions: { - canRegenerate: true, - canGiveFeedback: false, - }, - }); - }); - - describe('and the user edits a message', () 
=> { - beforeEach(() => { - act(() => { - hookResult.result.current.onEdit( - hookResult.result.current.items[1] as ChatTimelineItem, - { - '@timestamp': new Date().toISOString(), - message: { content: 'Edited message', role: MessageRole.User }, - } - ); - subject.next({ message: { role: MessageRole.Assistant, content: '' } }); - subject.complete(); - }); - }); - - it('calls onChatUpdate with the edited message', () => { - expect(hookResult.result.current.items.length).toEqual(4); - expect((hookResult.result.current.items[2] as ChatTimelineItem).content).toEqual( - 'Edited message' - ); - expect((hookResult.result.current.items[3] as ChatTimelineItem).content).toEqual(''); - }); - }); - }); - }); - - describe('and it is being aborted', () => { - beforeEach(() => { - act(() => { - subject.next({ message: { role: MessageRole.Assistant, content: 'My partial' } }); - subject.next({ - message: { - role: MessageRole.Assistant, - content: 'My partial', - }, - aborted: true, - error: new AbortError(), - }); - subject.complete(); - }); - }); - - it('adds the partial response', async () => { - expect(hookResult.result.current.items.length).toBe(3); - - expect(hookResult.result.current.items[2]).toEqual({ - actions: { - canEdit: false, - canRegenerate: true, - canGiveFeedback: false, - canCopy: true, - }, - display: { - collapsed: false, - hide: false, - }, - content: 'My partial', - id: expect.any(String), - loading: false, - title: '', - role: MessageRole.Assistant, - error: expect.any(AbortError), - }); - }); - - describe('and it is being regenerated', () => { - beforeEach(() => { - act(() => { - hookResult.result.current.onRegenerate( - hookResult.result.current.items[2] as ChatTimelineItem - ); - subject.next({ message: { role: MessageRole.Assistant, content: '' } }); - }); - }); - - it('updates the last item in the array to be loading', () => { - expect(hookResult.result.current.items.length).toEqual(3); - - expect(hookResult.result.current.items[2]).toEqual({ - 
display: { - hide: false, - collapsed: false, - }, - actions: { - canCopy: true, - canEdit: false, - canRegenerate: false, - canGiveFeedback: false, - }, - content: '', - id: expect.any(String), - loading: true, - title: '', - role: MessageRole.Assistant, - }); - }); - - describe('and it is regenerated again', () => { - beforeEach(async () => { - act(() => { - hookResult.result.current.onStopGenerating(); - }); - - act(() => { - hookResult.result.current.onRegenerate( - hookResult.result.current.items[2] as ChatTimelineItem - ); - }); - }); - - it('updates the last item to be not loading again', async () => { - expect(hookResult.result.current.items.length).toBe(3); - - expect(hookResult.result.current.items[2]).toEqual({ - actions: { - canCopy: true, - canEdit: false, - canRegenerate: false, - canGiveFeedback: false, - }, - display: { - collapsed: false, - hide: false, - }, - content: '', - id: expect.any(String), - loading: true, - title: '', - role: MessageRole.Assistant, - }); - - act(() => { - subject.next({ message: { role: MessageRole.Assistant, content: 'Regenerated' } }); - subject.complete(); - }); - - await hookResult.waitForNextUpdate(WAIT_OPTIONS); - - expect(hookResult.result.current.items.length).toBe(3); - - expect(hookResult.result.current.items[2]).toEqual({ - display: { - collapsed: false, - hide: false, - }, - actions: { - canCopy: true, - canEdit: false, - canRegenerate: true, - canGiveFeedback: false, - }, - content: 'Regenerated', - currentUser: undefined, - function_call: undefined, - id: expect.any(String), - element: undefined, - loading: false, - title: '', - role: MessageRole.Assistant, - }); - }); - }); - }); - }); - - describe('and a function call is returned', () => { - it('the function call is executed and its response is sent as a user reply', async () => { - jest.clearAllMocks(); - - act(() => { - subject.next({ - message: { - role: MessageRole.Assistant, - function_call: { - trigger: MessageRole.Assistant, - name: 'my_function', - 
arguments: '{}', - }, - }, - }); - subject.complete(); - }); - - props.chatService.executeFunction.mockResolvedValueOnce({ - content: { - message: 'my-response', - }, - }); - - await hookResult.waitForNextUpdate(WAIT_OPTIONS); - - expect(props.onChatUpdate).toHaveBeenCalledTimes(2); - - expect( - props.onChatUpdate.mock.calls[0][0].map( - (msg) => msg.message.content || msg.message.function_call?.name - ) - ).toEqual(['Hello', 'my_function']); - - expect( - props.onChatUpdate.mock.calls[1][0].map( - (msg) => msg.message.content || msg.message.function_call?.name - ) - ).toEqual(['Hello', 'my_function', JSON.stringify({ message: 'my-response' })]); - - expect(props.onChatComplete).not.toHaveBeenCalled(); - - expect(props.chatService.executeFunction).toHaveBeenCalledWith({ - name: 'my_function', - args: '{}', - connectorId: 'foo', - messages: [ - { - '@timestamp': expect.any(String), - message: { - content: 'Hello', - role: MessageRole.User, - }, - }, - ], - signal: expect.any(Object), - }); - - act(() => { - subject.next({ - message: { - role: MessageRole.Assistant, - content: 'looks like my-function returned my-response', - }, - }); - subject.complete(); - }); - - await hookResult.waitForNextUpdate(WAIT_OPTIONS); - - expect(props.onChatComplete).toHaveBeenCalledTimes(1); - - expect( - props.onChatComplete.mock.calls[0][0].map( - (msg) => msg.message.content || msg.message.function_call?.name - ) - ).toEqual([ - 'Hello', - 'my_function', - JSON.stringify({ message: 'my-response' }), - 'looks like my-function returned my-response', - ]); - }); - }); - }); - }); -}); diff --git a/x-pack/plugins/observability_ai_assistant/public/hooks/use_timeline.ts b/x-pack/plugins/observability_ai_assistant/public/hooks/use_timeline.ts deleted file mode 100644 index 64d82cabb94370..00000000000000 --- a/x-pack/plugins/observability_ai_assistant/public/hooks/use_timeline.ts +++ /dev/null @@ -1,387 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -import { i18n } from '@kbn/i18n'; -import { AbortError } from '@kbn/kibana-utils-plugin/common'; -import type { AuthenticatedUser } from '@kbn/security-plugin/common'; -import { flatten, last } from 'lodash'; -import { useEffect, useMemo, useRef, useState } from 'react'; -import usePrevious from 'react-use/lib/usePrevious'; -import { isObservable, Observable, Subscription } from 'rxjs'; -import { - ContextDefinition, - MessageRole, - type ConversationCreateRequest, - type Message, -} from '../../common/types'; -import type { ChatPromptEditorProps } from '../components/chat/chat_prompt_editor'; -import type { ChatTimelineItem, ChatTimelineProps } from '../components/chat/chat_timeline'; -import { ChatActionClickType } from '../components/chat/types'; -import { EMPTY_CONVERSATION_TITLE } from '../i18n'; -import type { ObservabilityAIAssistantChatService, PendingMessage } from '../types'; -import { - getTimelineItemsfromConversation, - StartedFrom, -} from '../utils/get_timeline_items_from_conversation'; -import type { UseGenAIConnectorsResult } from './use_genai_connectors'; -import { useKibana } from './use_kibana'; - -export function createNewConversation({ - contexts, -}: { - contexts: ContextDefinition[]; -}): ConversationCreateRequest { - return { - '@timestamp': new Date().toISOString(), - messages: [], - conversation: { - title: EMPTY_CONVERSATION_TITLE, - }, - labels: {}, - numeric_labels: {}, - public: false, - }; -} - -export type UseTimelineResult = Pick< - ChatTimelineProps, - 'onEdit' | 'onFeedback' | 'onRegenerate' | 'onStopGenerating' | 'onActionClick' | 'items' -> & - Pick; - -export function useTimeline({ - messages, - connectors, - conversationId, - currentUser, - chatService, - startedFrom, - onChatUpdate, - onChatComplete, -}: { - messages: Message[]; - 
conversationId?: string; - connectors: UseGenAIConnectorsResult; - currentUser?: Pick; - chatService: ObservabilityAIAssistantChatService; - startedFrom?: StartedFrom; - onChatUpdate: (messages: Message[]) => void; - onChatComplete: (messages: Message[]) => void; -}): UseTimelineResult { - const connectorId = connectors.selectedConnector; - - const hasConnector = !!connectorId; - - const { - services: { notifications }, - } = useKibana(); - - const conversationItems = useMemo(() => { - const items = getTimelineItemsfromConversation({ - currentUser, - chatService, - hasConnector, - messages, - startedFrom, - }); - - return items; - }, [currentUser, chatService, hasConnector, messages, startedFrom]); - - const [subscription, setSubscription] = useState(); - - const controllerRef = useRef(new AbortController()); - - const [pendingMessage, setPendingMessage] = useState(); - - const [isFunctionLoading, setIsFunctionLoading] = useState(false); - - const prevConversationId = usePrevious(conversationId); - - useEffect(() => { - if (prevConversationId !== conversationId && pendingMessage?.error) { - setPendingMessage(undefined); - } - }, [conversationId, pendingMessage?.error, prevConversationId]); - - function chat( - nextMessages: Message[], - response$: Observable | undefined = undefined - ): Promise { - const controller = new AbortController(); - - return new Promise(async (resolve, reject) => { - try { - if (!connectorId) { - reject(new Error('Can not add a message without a connector')); - return; - } - - const isStartOfConversation = - nextMessages.some((message) => message.message.role === MessageRole.Assistant) === false; - - if (isStartOfConversation && chatService.hasFunction('recall')) { - nextMessages = nextMessages.concat({ - '@timestamp': new Date().toISOString(), - message: { - role: MessageRole.Assistant, - content: '', - function_call: { - name: 'recall', - arguments: JSON.stringify({ queries: [], contexts: [] }), - trigger: MessageRole.User, - }, - }, - 
}); - } - - onChatUpdate(nextMessages); - const lastMessage = last(nextMessages); - if (lastMessage?.message.function_call?.name) { - // the user has edited a function suggestion, no need to talk to the LLM - resolve(undefined); - return; - } - - response$ = - response$ || - chatService!.chat({ - messages: nextMessages, - connectorId, - }); - let pendingMessageLocal = pendingMessage; - const nextSubscription = response$.subscribe({ - next: (nextPendingMessage) => { - pendingMessageLocal = nextPendingMessage; - setPendingMessage(() => nextPendingMessage); - }, - error: reject, - complete: () => { - const error = pendingMessageLocal?.error; - if (error) { - notifications.toasts.addError(error, { - title: i18n.translate('xpack.observabilityAiAssistant.failedToLoadResponse', { - defaultMessage: 'Failed to load response from the AI Assistant', - }), - }); - } - resolve(pendingMessageLocal!); - }, - }); - setSubscription(() => { - controllerRef.current = controller; - return nextSubscription; - }); - } catch (error) { - reject(error); - } - }).then(async (reply) => { - if (reply?.error) { - return nextMessages; - } - if (reply?.aborted) { - return nextMessages; - } - - setPendingMessage(undefined); - - const messagesAfterChat = reply - ? 
nextMessages.concat({ - '@timestamp': new Date().toISOString(), - message: { - ...reply.message, - }, - }) - : nextMessages; - - onChatUpdate(messagesAfterChat); - - const lastMessage = last(messagesAfterChat); - - if (lastMessage?.message.function_call?.name) { - const name = lastMessage.message.function_call.name; - - setIsFunctionLoading(true); - - try { - let message = await chatService!.executeFunction({ - name, - args: lastMessage.message.function_call.arguments, - messages: messagesAfterChat.slice(0, -1), - signal: controller.signal, - connectorId: connectorId!, - }); - - let nextResponse$: Observable | undefined; - - if (isObservable(message)) { - nextResponse$ = message; - message = { content: '', data: '' }; - } - - return await chat( - messagesAfterChat.concat({ - '@timestamp': new Date().toISOString(), - message: { - name, - role: MessageRole.User, - content: JSON.stringify(message.content), - data: JSON.stringify(message.data), - }, - }), - nextResponse$ - ); - } catch (error) { - return await chat( - messagesAfterChat.concat({ - '@timestamp': new Date().toISOString(), - message: { - role: MessageRole.User, - name, - content: JSON.stringify({ - message: error.toString(), - error, - }), - }, - }) - ); - } finally { - setIsFunctionLoading(false); - } - } - - return messagesAfterChat; - }); - } - - const itemsWithAddedLoadingStates = useMemo(() => { - // While we're loading we add an empty loading chat item: - if (pendingMessage || isFunctionLoading) { - const nextItems = conversationItems.concat({ - id: '', - actions: { - canCopy: true, - canEdit: false, - canGiveFeedback: false, - canRegenerate: pendingMessage?.aborted || !!pendingMessage?.error, - }, - display: { - collapsed: false, - hide: pendingMessage?.message.role === MessageRole.System, - }, - content: pendingMessage?.message.content, - currentUser, - error: pendingMessage?.error, - function_call: pendingMessage?.message.function_call, - loading: !pendingMessage?.aborted && 
!pendingMessage?.error, - role: pendingMessage?.message.role || MessageRole.Assistant, - title: '', - }); - - return nextItems; - } - - if (!isFunctionLoading) { - return conversationItems; - } - - return conversationItems.map((item, index) => { - // When we're done loading we remove the placeholder item again - if (index < conversationItems.length - 1) { - return item; - } - return { - ...item, - loading: true, - }; - }); - }, [conversationItems, pendingMessage, currentUser, isFunctionLoading]); - - const items = useMemo(() => { - const consolidatedChatItems: Array = []; - let currentGroup: ChatTimelineItem[] | null = null; - - for (const item of itemsWithAddedLoadingStates) { - if (item.display.hide || !item) continue; - - if (item.display.collapsed) { - if (currentGroup) { - currentGroup.push(item); - } else { - currentGroup = [item]; - consolidatedChatItems.push(currentGroup); - } - } else { - consolidatedChatItems.push(item); - currentGroup = null; - } - } - - return consolidatedChatItems; - }, [itemsWithAddedLoadingStates]); - - useEffect(() => { - return () => { - subscription?.unsubscribe(); - }; - }, [subscription]); - - return { - items, - onEdit: async (item, newMessage) => { - const indexOf = flatten(items).indexOf(item); - const sliced = messages.slice(0, indexOf); - const nextMessages = await chat(sliced.concat(newMessage)); - onChatComplete(nextMessages); - }, - onFeedback: (item, feedback) => {}, - onRegenerate: (item) => { - const indexOf = flatten(items).indexOf(item); - - chat(messages.slice(0, indexOf)).then((nextMessages) => onChatComplete(nextMessages)); - }, - onStopGenerating: () => { - subscription?.unsubscribe(); - setPendingMessage((prevPendingMessage) => ({ - message: { - role: MessageRole.Assistant, - ...prevPendingMessage?.message, - }, - aborted: true, - error: new AbortError(), - })); - setSubscription(undefined); - }, - onSubmit: async (message) => { - const nextMessages = await chat(messages.concat(message)); - 
onChatComplete(nextMessages); - }, - onActionClick: async (payload) => { - switch (payload.type) { - case ChatActionClickType.executeEsqlQuery: - const nextMessages = await chat( - messages.concat({ - '@timestamp': new Date().toISOString(), - message: { - role: MessageRole.Assistant, - content: '', - function_call: { - name: 'execute_query', - arguments: JSON.stringify({ - query: payload.query, - }), - trigger: MessageRole.User, - }, - }, - }) - ); - onChatComplete(nextMessages); - break; - } - }, - }; -} diff --git a/x-pack/plugins/observability_ai_assistant/public/routes/config.tsx b/x-pack/plugins/observability_ai_assistant/public/routes/config.tsx index f4245cb69d9e7b..ed0ac18302cc5f 100644 --- a/x-pack/plugins/observability_ai_assistant/public/routes/config.tsx +++ b/x-pack/plugins/observability_ai_assistant/public/routes/config.tsx @@ -31,11 +31,18 @@ const observabilityAIAssistantRoutes = { element: , }, '/conversations/{conversationId}': { - params: t.type({ - path: t.type({ - conversationId: t.string, + params: t.intersection([ + t.type({ + path: t.type({ + conversationId: t.string, + }), }), - }), + t.partial({ + state: t.partial({ + prevConversationKey: t.string, + }), + }), + ]), element: , }, '/conversations': { diff --git a/x-pack/plugins/observability_ai_assistant/public/routes/conversations/conversation_view.tsx b/x-pack/plugins/observability_ai_assistant/public/routes/conversations/conversation_view.tsx index 5af2e740deb9b3..0b6c1cba60523b 100644 --- a/x-pack/plugins/observability_ai_assistant/public/routes/conversations/conversation_view.tsx +++ b/x-pack/plugins/observability_ai_assistant/public/routes/conversations/conversation_view.tsx @@ -4,17 +4,18 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. 
*/ -import React, { useMemo, useState } from 'react'; -import { EuiCallOut, EuiFlexGroup, EuiFlexItem, EuiLoadingSpinner, EuiSpacer } from '@elastic/eui'; +import { EuiFlexGroup, EuiFlexItem, EuiLoadingSpinner, EuiSpacer } from '@elastic/eui'; import { css } from '@emotion/css'; import { i18n } from '@kbn/i18n'; import { euiThemeVars } from '@kbn/ui-theme'; +import React, { useMemo, useRef, useState } from 'react'; +import usePrevious from 'react-use/lib/usePrevious'; +import { v4 } from 'uuid'; import { ChatBody } from '../../components/chat/chat_body'; import { ConversationList } from '../../components/chat/conversation_list'; import { ObservabilityAIAssistantChatServiceProvider } from '../../context/observability_ai_assistant_chat_service_provider'; import { useAbortableAsync } from '../../hooks/use_abortable_async'; import { useConfirmModal } from '../../hooks/use_confirm_modal'; -import { useConversation } from '../../hooks/use_conversation'; import { useCurrentUser } from '../../hooks/use_current_user'; import { useGenAIConnectors } from '../../hooks/use_genai_connectors'; import { useKibana } from '../../hooks/use_kibana'; @@ -22,18 +23,14 @@ import { useKnowledgeBase } from '../../hooks/use_knowledge_base'; import { useObservabilityAIAssistant } from '../../hooks/use_observability_ai_assistant'; import { useObservabilityAIAssistantParams } from '../../hooks/use_observability_ai_assistant_params'; import { useObservabilityAIAssistantRouter } from '../../hooks/use_observability_ai_assistant_router'; +import { EMPTY_CONVERSATION_TITLE } from '../../i18n'; import { getConnectorsManagementHref } from '../../utils/get_connectors_management_href'; import { getModelsManagementHref } from '../../utils/get_models_management_href'; -import { EMPTY_CONVERSATION_TITLE } from '../../i18n'; const containerClassName = css` max-width: 100%; `; -const chatBodyContainerClassNameWithError = css` - align-self: center; -`; - const conversationListContainerName = css` min-width: 
250px; width: 250px; @@ -80,12 +77,22 @@ export function ConversationView() { const conversationId = 'conversationId' in path ? path.conversationId : undefined; - const { conversation, displayedMessages, setDisplayedMessages, save, saveTitle } = - useConversation({ - conversationId, - chatService: chatService.value, - connectorId: connectors.selectedConnector, - }); + // Regenerate the key only when the id changes, except after + // creating the conversation. Ideally this happens by adding + // state to the current route, but I'm not keen on adding + // the concept of state to the router, due to a mismatch + // between router.link() and router.push(). So, this is a + // pretty gross workaround for persisting a key under some + // conditions. + const chatBodyKeyRef = useRef(v4()); + const keepPreviousKeyRef = useRef(false); + const prevConversationId = usePrevious(conversationId); + + if (conversationId !== prevConversationId && keepPreviousKeyRef.current === false) { + chatBodyKeyRef.current = v4(); + } + + keepPreviousKeyRef.current = false; const conversations = useAbortableAsync( ({ signal }) => { @@ -111,14 +118,17 @@ export function ConversationView() { ]; }, [conversations.value?.conversations, conversationId, observabilityAIAssistantRouter]); - function navigateToConversation(nextConversationId?: string) { - observabilityAIAssistantRouter.push( - nextConversationId ? 
'/conversations/{conversationId}' : '/conversations/new', - { - path: { conversationId: nextConversationId }, + function navigateToConversation(nextConversationId?: string, usePrevConversationKey?: boolean) { + if (nextConversationId) { + observabilityAIAssistantRouter.push('/conversations/{conversationId}', { + path: { + conversationId: nextConversationId, + }, query: {}, - } - ); + }); + } else { + observabilityAIAssistantRouter.push('/conversations/new', { path: {}, query: {} }); + } } function handleRefreshConversations() { @@ -194,69 +204,36 @@ export function ConversationView() { /> - - {conversation.error ? ( - + + + + + + ) : null} + {chatService.value && ( + + { + if (!conversationId) { + keepPreviousKeyRef.current = true; + navigateToConversation(conversation.conversation.id); } - )} - iconType="warning" - > - {i18n.translate('xpack.observabilityAiAssistant.couldNotFindConversationContent', { - defaultMessage: - 'Could not find a conversation with id {conversationId}. Make sure the conversation exists and you have access to it.', - values: { conversationId }, - })} - - ) : null} - {!chatService.value ? ( - - - - - - - ) : null} - {conversation.value && chatService.value && !conversation.error ? 
( - - { - setDisplayedMessages(messages); - }} - onChatComplete={(messages) => { - save(messages, handleRefreshConversations) - .then((nextConversation) => { - conversations.refresh(); - if (!conversationId && nextConversation?.conversation?.id) { - navigateToConversation(nextConversation.conversation.id); - } - }) - .catch((e) => {}); - }} - onSaveTitle={(title) => { - saveTitle(title, handleRefreshConversations); - }} - /> - - ) : null} - + handleRefreshConversations(); + }} + /> + + )} ); diff --git a/x-pack/plugins/observability_ai_assistant/public/service/create_mock_chat_service.ts b/x-pack/plugins/observability_ai_assistant/public/service/create_mock_chat_service.ts new file mode 100644 index 00000000000000..e255aa830467e4 --- /dev/null +++ b/x-pack/plugins/observability_ai_assistant/public/service/create_mock_chat_service.ts @@ -0,0 +1,24 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +import type { DeeplyMockedKeys } from '@kbn/utility-types-jest'; +import type { ObservabilityAIAssistantChatService } from '../types'; + +type MockedChatService = DeeplyMockedKeys; + +export const createMockChatService = (): MockedChatService => { + const mockChatService: MockedChatService = { + chat: jest.fn(), + executeFunction: jest.fn(), + getContexts: jest.fn().mockReturnValue([{ name: 'core', description: '' }]), + getFunctions: jest.fn().mockReturnValue([]), + hasFunction: jest.fn().mockReturnValue(false), + hasRenderFunction: jest.fn().mockReturnValue(true), + renderFunction: jest.fn(), + }; + return mockChatService; +}; diff --git a/x-pack/plugins/observability_ai_assistant/public/utils/builders.ts b/x-pack/plugins/observability_ai_assistant/public/utils/builders.ts index ed318397de73c4..6f2d1e5c2f0905 100644 --- a/x-pack/plugins/observability_ai_assistant/public/utils/builders.ts +++ b/x-pack/plugins/observability_ai_assistant/public/utils/builders.ts @@ -6,106 +6,98 @@ */ import { merge, uniqueId } from 'lodash'; -import { MessageRole, Conversation, FunctionDefinition } from '../../common/types'; -import { ChatTimelineItem } from '../components/chat/chat_timeline'; +import { DeepPartial } from 'utility-types'; +import { MessageRole, Conversation, FunctionDefinition, Message } from '../../common/types'; import { getAssistantSetupMessage } from '../service/get_assistant_setup_message'; -type ChatItemBuildProps = Omit, 'actions' | 'display' | 'currentUser'> & { - actions?: Partial; - display?: Partial; - currentUser?: Partial; -} & Pick; +type BuildMessageProps = DeepPartial & { + message: { + role: MessageRole; + function_call?: { + name: string; + trigger: MessageRole.Assistant | MessageRole.User | MessageRole.Elastic; + }; + }; +}; -export function buildChatItem(params: ChatItemBuildProps): ChatTimelineItem { +export function buildMessage(params: BuildMessageProps): Message { return merge( { - id: uniqueId(), - title: '', - actions: { - canCopy: 
true, - canEdit: false, - canGiveFeedback: false, - canRegenerate: params.role === MessageRole.Assistant, - }, - display: { - collapsed: false, - hide: false, - }, - currentUser: { - username: 'elastic', - }, - loading: false, + '@timestamp': new Date().toISOString(), }, params ); } -export function buildSystemChatItem(params?: Omit) { - return buildChatItem({ - role: MessageRole.System, - ...params, - }); -} - -export function buildChatInitItem() { - return buildChatItem({ - role: MessageRole.User, - title: 'started a conversation', - actions: { - canEdit: false, - canCopy: true, - canGiveFeedback: false, - canRegenerate: false, - }, - }); -} - -export function buildUserChatItem(params?: Omit) { - return buildChatItem({ - role: MessageRole.User, - content: "What's a function?", - actions: { - canCopy: true, - canEdit: true, - canGiveFeedback: false, - canRegenerate: true, - }, - ...params, - }); +export function buildSystemMessage( + params?: Omit & { + message: DeepPartial>; + } +) { + return buildMessage( + merge({}, params, { + message: { role: MessageRole.System }, + }) + ); } -export function buildAssistantChatItem(params?: Omit) { - return buildChatItem({ - role: MessageRole.Assistant, - content: `In computer programming and mathematics, a function is a fundamental concept that represents a relationship between input values and output values. It takes one or more input values (also known as arguments or parameters) and processes them to produce a result, which is the output of the function. The input values are passed to the function, and the function performs a specific set of operations or calculations on those inputs to produce the desired output. - A function is often defined with a name, which serves as an identifier to call and use the function in the code. 
It can be thought of as a reusable block of code that can be executed whenever needed, and it helps in organizing code and making it more modular and maintainable.`, - actions: { - canCopy: true, - canEdit: false, - canRegenerate: true, - canGiveFeedback: true, - }, - ...params, - }); +export function buildUserMessage( + params?: Omit & { + message?: DeepPartial>; + } +) { + return buildMessage( + merge( + { + message: { + content: "What's a function?", + }, + }, + params, + { + message: { role: MessageRole.User }, + } + ) + ); } -export function buildFunctionChatItem(params: Omit) { - return buildChatItem({ - role: MessageRole.User, - title: 'executed a function', - function_call: { - name: 'leftpad', - arguments: '{ foo: "bar" }', - trigger: MessageRole.Assistant, - }, - ...params, - }); +export function buildAssistantMessage( + params?: Omit & { + message: DeepPartial>; + } +) { + return buildMessage( + merge( + { + message: { + content: `In computer programming and mathematics, a function is a fundamental concept that represents a relationship between input values and output values. It takes one or more input values (also known as arguments or parameters) and processes them to produce a result, which is the output of the function. The input values are passed to the function, and the function performs a specific set of operations or calculations on those inputs to produce the desired output. + A function is often defined with a name, which serves as an identifier to call and use the function in the code. 
It can be thought of as a reusable block of code that can be executed whenever needed, and it helps in organizing code and making it more modular and maintainable.`, + }, + }, + params, + { + message: { role: MessageRole.Assistant }, + } + ) + ); } -export function buildTimelineItems() { - return { - items: [buildSystemChatItem(), buildUserChatItem(), buildAssistantChatItem()], - }; +export function buildFunctionResponseMessage( + params?: Omit & { + message: DeepPartial>; + } +) { + return buildUserMessage( + merge( + {}, + { + message: { + name: 'leftpad', + }, + ...params, + } + ) + ); } export function buildConversation(params?: Partial) { diff --git a/x-pack/plugins/observability_ai_assistant/public/utils/get_timeline_items_from_conversation.test.tsx b/x-pack/plugins/observability_ai_assistant/public/utils/get_timeline_items_from_conversation.test.tsx new file mode 100644 index 00000000000000..f066b7a9db370c --- /dev/null +++ b/x-pack/plugins/observability_ai_assistant/public/utils/get_timeline_items_from_conversation.test.tsx @@ -0,0 +1,597 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ +import React from 'react'; +import { last, pick } from 'lodash'; +import { render } from '@testing-library/react'; +import { Message, MessageRole } from '../../common'; +import { createMockChatService } from '../service/create_mock_chat_service'; +import { getTimelineItemsfromConversation } from './get_timeline_items_from_conversation'; +import { __IntlProvider as IntlProvider } from '@kbn/i18n-react'; +import { ObservabilityAIAssistantChatServiceProvider } from '../context/observability_ai_assistant_chat_service_provider'; +import { ChatState } from '../hooks/use_chat'; + +const mockChatService = createMockChatService(); + +let items: ReturnType; + +describe('getTimelineItemsFromConversation', () => { + describe('returns an opening message only', () => { + items = getTimelineItemsfromConversation({ + chatService: mockChatService, + hasConnector: true, + messages: [], + chatState: ChatState.Ready, + }); + + expect(items.length).toBe(1); + expect(items[0].title).toBe('started a conversation'); + }); + + describe('with a start of a conversation', () => { + beforeEach(() => { + items = getTimelineItemsfromConversation({ + chatService: mockChatService, + hasConnector: true, + currentUser: { + username: 'johndoe', + full_name: 'John Doe', + }, + chatState: ChatState.Ready, + messages: [ + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.System, + content: 'System', + }, + }, + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.User, + content: 'User', + }, + }, + ], + }); + }); + it('excludes the system message', () => { + expect(items.length).toBe(2); + expect(items[0].title).toBe('started a conversation'); + }); + + it('includes the rest of the conversation', () => { + expect(items[1].currentUser?.full_name).toEqual('John Doe'); + expect(items[1].content).toEqual('User'); + }); + + it('formats the user message', () => { + expect(pick(items[1], 'title', 'actions', 'display', 'loading')).toEqual({ + title: '', 
+ actions: { + canCopy: true, + canEdit: true, + canGiveFeedback: false, + canRegenerate: false, + }, + display: { + collapsed: false, + hide: false, + }, + loading: false, + }); + }); + }); + + describe('with function calling', () => { + beforeEach(() => { + mockChatService.hasRenderFunction.mockImplementation(() => false); + items = getTimelineItemsfromConversation({ + chatService: mockChatService, + hasConnector: true, + chatState: ChatState.Ready, + messages: [ + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.System, + content: 'System', + }, + }, + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.User, + content: 'Hello', + }, + }, + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.Assistant, + function_call: { + name: 'recall', + arguments: JSON.stringify({ queries: [], contexts: [] }), + trigger: MessageRole.Assistant, + }, + }, + }, + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.User, + name: 'recall', + content: JSON.stringify([]), + }, + }, + ], + }); + }); + + it('formats the function request', () => { + expect(pick(items[2], 'actions', 'display', 'loading')).toEqual({ + actions: { + canCopy: true, + canEdit: true, + canGiveFeedback: false, + canRegenerate: true, + }, + display: { + collapsed: true, + hide: false, + }, + loading: false, + }); + + const { container } = render(items[2].title as React.ReactElement, { + wrapper: ({ children }) => ( + + {children} + + ), + }); + + expect(container.textContent).toBe('requested the function recall'); + }); + + it('formats the function response', () => { + expect(pick(items[3], 'actions', 'display', 'loading')).toEqual({ + actions: { + canCopy: true, + canEdit: false, + canGiveFeedback: false, + canRegenerate: false, + }, + display: { + collapsed: true, + hide: false, + }, + loading: false, + }); + + const { container } = render(items[3].title as React.ReactElement, { + wrapper: ({ children }) => 
( + + {children} + + ), + }); + + expect(container.textContent).toBe('executed the function recall'); + }); + }); + describe('with a render function', () => { + beforeEach(() => { + mockChatService.hasRenderFunction.mockImplementation(() => true); + mockChatService.renderFunction.mockImplementation(() => 'Rendered'); + items = getTimelineItemsfromConversation({ + chatService: mockChatService, + hasConnector: true, + chatState: ChatState.Ready, + messages: [ + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.System, + content: 'System', + }, + }, + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.User, + content: 'Hello', + }, + }, + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.Assistant, + function_call: { + name: 'my_render_function', + arguments: JSON.stringify({ foo: 'bar' }), + trigger: MessageRole.Assistant, + }, + }, + }, + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.User, + name: 'my_render_function', + content: JSON.stringify([]), + }, + }, + ], + }); + }); + + it('renders a display element', () => { + expect(mockChatService.hasRenderFunction).toHaveBeenCalledWith('my_render_function'); + + expect(pick(items[3], 'actions', 'display')).toEqual({ + actions: { + canCopy: true, + canEdit: false, + canGiveFeedback: false, + canRegenerate: false, + }, + display: { + collapsed: false, + hide: false, + }, + }); + + expect(items[3].element).toBeTruthy(); + + const { container } = render(items[3].element as React.ReactElement, { + wrapper: ({ children }) => ( + + + {children} + + + ), + }); + + expect(mockChatService.renderFunction).toHaveBeenCalledWith( + 'my_render_function', + JSON.stringify({ foo: 'bar' }), + { content: '[]', name: 'my_render_function', role: 'user' } + ); + + expect(container.textContent).toEqual('Rendered'); + }); + }); + + describe('with a function that errors out', () => { + beforeEach(() => { + items = 
getTimelineItemsfromConversation({ + chatService: mockChatService, + hasConnector: true, + chatState: ChatState.Ready, + messages: [ + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.System, + content: 'System', + }, + }, + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.User, + content: 'Hello', + }, + }, + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.Assistant, + function_call: { + name: 'my_render_function', + arguments: JSON.stringify({ foo: 'bar' }), + trigger: MessageRole.Assistant, + }, + }, + }, + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.User, + name: 'my_render_function', + content: JSON.stringify({ + error: { + message: 'An error occurred', + }, + }), + }, + }, + ], + }); + }); + + it('returns a title that reflects a failure to execute the function', () => { + const { container } = render(items[3].title as React.ReactElement, { + wrapper: ({ children }) => ( + + {children} + + ), + }); + + expect(container.textContent).toBe('failed to execute the function my_render_function'); + }); + + it('formats the messages correctly', () => { + expect(pick(items[3], 'actions', 'display', 'loading')).toEqual({ + actions: { + canCopy: true, + canEdit: false, + canGiveFeedback: false, + canRegenerate: false, + }, + display: { + collapsed: true, + hide: false, + }, + loading: false, + }); + }); + }); + + describe('with an invalid JSON response', () => { + beforeEach(() => { + items = getTimelineItemsfromConversation({ + chatService: mockChatService, + hasConnector: true, + currentUser: { + username: 'johndoe', + full_name: 'John Doe', + }, + chatState: ChatState.Ready, + messages: [ + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.System, + content: 'System', + }, + }, + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.Assistant, + content: '', + function_call: { + name: 'my_function', + 
arguments: JSON.stringify({}), + trigger: MessageRole.User, + }, + }, + }, + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.User, + content: 'invalid-json', + name: 'my_function', + }, + }, + ], + }); + }); + + it('sets the invalid json as content', () => { + expect(items[2].content).toBe( + `\`\`\` +{ + "content": "invalid-json" +} +\`\`\`` + ); + }); + }); + + describe('when starting from a contextual insight', () => { + beforeEach(() => { + items = getTimelineItemsfromConversation({ + chatService: mockChatService, + hasConnector: true, + currentUser: { + username: 'johndoe', + full_name: 'John Doe', + }, + chatState: ChatState.Ready, + startedFrom: 'contextualInsight', + messages: [ + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.System, + content: 'System', + }, + }, + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.User, + content: 'Test', + }, + }, + ], + }); + }); + + it('hides the first user message', () => { + expect(items[1].display.collapsed).toBe(true); + }); + }); + + describe('with function calling suggested by the user', () => { + beforeEach(() => { + mockChatService.hasRenderFunction.mockImplementation(() => false); + items = getTimelineItemsfromConversation({ + chatService: mockChatService, + hasConnector: true, + chatState: ChatState.Ready, + messages: [ + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.System, + content: 'System', + }, + }, + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.Assistant, + function_call: { + name: 'recall', + arguments: JSON.stringify({ queries: [], contexts: [] }), + trigger: MessageRole.User, + }, + }, + }, + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.User, + name: 'recall', + content: JSON.stringify([]), + }, + }, + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.Assistant, + content: 'Reply from 
assistant', + }, + }, + ], + }); + }); + + it('formats the function request', () => { + expect(pick(items[1], 'actions', 'display')).toEqual({ + actions: { + canCopy: true, + canRegenerate: false, + canEdit: true, + canGiveFeedback: false, + }, + display: { + collapsed: true, + hide: false, + }, + }); + }); + + it('formats the assistant response', () => { + expect(pick(items[3], 'actions', 'display')).toEqual({ + actions: { + canCopy: true, + canRegenerate: true, + canEdit: false, + canGiveFeedback: false, + }, + display: { + collapsed: false, + hide: false, + }, + }); + }); + }); + + describe('while the chat is loading', () => { + const renderWithLoading = (extraMessages: Message[]) => { + items = getTimelineItemsfromConversation({ + chatService: mockChatService, + hasConnector: true, + chatState: ChatState.Loading, + messages: [ + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.System, + content: 'System', + }, + }, + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.User, + content: 'Test', + }, + }, + ...extraMessages, + ], + }); + }; + + describe('with a user message last', () => { + beforeEach(() => { + renderWithLoading([]); + }); + + it('adds an assistant message which is loading', () => { + expect(pick(last(items), 'display', 'actions', 'loading', 'role', 'content')).toEqual({ + loading: true, + role: MessageRole.Assistant, + actions: { + canCopy: false, + canRegenerate: false, + canEdit: false, + canGiveFeedback: false, + }, + display: { + collapsed: false, + hide: false, + }, + content: '', + }); + }); + }); + + describe('with a function request as the last message', () => { + beforeEach(() => { + renderWithLoading([ + { + '@timestamp': new Date().toISOString(), + message: { + function_call: { + name: 'my_function_call', + trigger: MessageRole.Assistant, + }, + role: MessageRole.Assistant, + }, + }, + ]); + }); + + it('adds an assistant message which is loading', () => { + expect(pick(last(items), 
'display', 'actions', 'loading', 'role', 'content')).toEqual({ + loading: true, + role: MessageRole.Assistant, + actions: { + canCopy: false, + canRegenerate: false, + canEdit: false, + canGiveFeedback: false, + }, + display: { + collapsed: false, + hide: false, + }, + content: '', + }); + }); + }); + }); +}); diff --git a/x-pack/plugins/observability_ai_assistant/public/utils/get_timeline_items_from_conversation.tsx b/x-pack/plugins/observability_ai_assistant/public/utils/get_timeline_items_from_conversation.tsx index 61f2a4fcbd383b..3bf54ec628f244 100644 --- a/x-pack/plugins/observability_ai_assistant/public/utils/get_timeline_items_from_conversation.tsx +++ b/x-pack/plugins/observability_ai_assistant/public/utils/get_timeline_items_from_conversation.tsx @@ -6,7 +6,7 @@ */ import React from 'react'; import { v4 } from 'uuid'; -import { isEmpty, omitBy } from 'lodash'; +import { isEmpty, last, omitBy } from 'lodash'; import { useEuiTheme } from '@elastic/eui'; import { i18n } from '@kbn/i18n'; import { FormattedMessage } from '@kbn/i18n-react'; @@ -15,6 +15,15 @@ import { Message, MessageRole } from '../../common'; import type { ChatTimelineItem } from '../components/chat/chat_timeline'; import { RenderFunction } from '../components/render_function'; import type { ObservabilityAIAssistantChatService } from '../types'; +import { ChatState } from '../hooks/use_chat'; + +function safeParse(jsonStr: string) { + try { + return JSON.parse(jsonStr); + } catch (err) { + return jsonStr; + } +} function convertMessageToMarkdownCodeBlock(message: Message['message']) { let value: object; @@ -22,7 +31,7 @@ function convertMessageToMarkdownCodeBlock(message: Message['message']) { if (!message.name) { const name = message.function_call?.name; const args = message.function_call?.arguments - ? JSON.parse(message.function_call.arguments) + ? 
safeParse(message.function_call.arguments) : undefined; value = { @@ -32,9 +41,9 @@ function convertMessageToMarkdownCodeBlock(message: Message['message']) { } else { const content = message.role !== MessageRole.Assistant && message.content - ? JSON.parse(message.content) + ? safeParse(message.content) : message.content; - const data = message.data ? JSON.parse(message.data) : undefined; + const data = message.data ? safeParse(message.data) : undefined; value = omitBy( { content, @@ -61,26 +70,36 @@ export function getTimelineItemsfromConversation({ hasConnector, messages, startedFrom, + chatState, }: { chatService: ObservabilityAIAssistantChatService; currentUser?: Pick; hasConnector: boolean; messages: Message[]; startedFrom?: StartedFrom; + chatState: ChatState; }): ChatTimelineItem[] { - return [ + const messagesWithoutSystem = messages.filter( + (message) => message.message.role !== MessageRole.System + ); + + const items: ChatTimelineItem[] = [ { id: v4(), actions: { canCopy: false, canEdit: false, canGiveFeedback: false, canRegenerate: false }, display: { collapsed: false, hide: false }, currentUser, loading: false, - role: MessageRole.User, + message: { + '@timestamp': new Date().toISOString(), + message: { role: MessageRole.User }, + }, title: i18n.translate('xpack.observabilityAiAssistant.conversationStartTitle', { defaultMessage: 'started a conversation', }), + role: MessageRole.User, }, - ...messages.map((message, index) => { + ...messagesWithoutSystem.map((message, index) => { const id = v4(); let title: React.ReactNode = ''; @@ -88,8 +107,10 @@ export function getTimelineItemsfromConversation({ let element: React.ReactNode | undefined; const prevFunctionCall = - message.message.name && messages[index - 1] && messages[index - 1].message.function_call - ? messages[index - 1].message.function_call + message.message.name && + messagesWithoutSystem[index - 1] && + messagesWithoutSystem[index - 1].message.function_call + ? 
messagesWithoutSystem[index - 1].message.function_call : undefined; let role = message.message.function_call?.trigger || message.message.role; @@ -107,10 +128,6 @@ export function getTimelineItemsfromConversation({ }; switch (role) { - case MessageRole.System: - display.hide = true; - break; - case MessageRole.User: actions.canCopy = true; actions.canGiveFeedback = false; @@ -120,16 +137,15 @@ export function getTimelineItemsfromConversation({ // User executed a function: if (message.message.name && prevFunctionCall) { - let parsedContent; + let isError = false; try { - parsedContent = JSON.parse(message.message.content ?? 'null'); + const parsedContent = JSON.parse(message.message.content ?? 'null'); + isError = + parsedContent && typeof parsedContent === 'object' && 'error' in parsedContent; } catch (error) { - parsedContent = message.message.content; + isError = true; } - const isError = - parsedContent && typeof parsedContent === 'object' && 'error' in parsedContent; - title = !isError ? 
( el.message.role === MessageRole.User ); @@ -252,7 +268,53 @@ export function getTimelineItemsfromConversation({ currentUser, function_call: message.message.function_call, loading: false, + message, }; }), ]; + + const isLoading = chatState === ChatState.Loading; + + let lastMessage = last(items); + + const isNaturalLanguageOnlyAnswerFromAssistant = + lastMessage?.message.message.role === MessageRole.Assistant && + !lastMessage.message.message.function_call?.name; + + const addLoadingPlaceholder = isLoading && !isNaturalLanguageOnlyAnswerFromAssistant; + + if (addLoadingPlaceholder) { + items.push({ + id: v4(), + actions: { + canCopy: false, + canEdit: false, + canGiveFeedback: false, + canRegenerate: false, + }, + display: { + collapsed: false, + hide: false, + }, + content: '', + currentUser, + loading: chatState === ChatState.Loading, + role: MessageRole.Assistant, + title: '', + message: { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.Assistant, + content: '', + }, + }, + }); + lastMessage = last(items); + } + + if (isLoading && lastMessage) { + lastMessage.loading = isLoading; + } + + return items; } From a246f9760f7bc4b1d00abf683490dd8dee5fea39 Mon Sep 17 00:00:00 2001 From: kibanamachine <42973632+kibanamachine@users.noreply.github.com> Date: Mon, 27 Nov 2023 10:11:03 +0000 Subject: [PATCH 03/15] [CI] Auto-commit changed files from 'node scripts/lint_ts_projects --fix' --- x-pack/plugins/observability_ai_assistant/tsconfig.json | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/x-pack/plugins/observability_ai_assistant/tsconfig.json b/x-pack/plugins/observability_ai_assistant/tsconfig.json index afdc9a4a892436..93817dcf791962 100644 --- a/x-pack/plugins/observability_ai_assistant/tsconfig.json +++ b/x-pack/plugins/observability_ai_assistant/tsconfig.json @@ -46,7 +46,8 @@ "@kbn/es-query", "@kbn/rule-registry-plugin", "@kbn/licensing-plugin", - "@kbn/share-plugin" + "@kbn/share-plugin", + 
"@kbn/utility-types-jest" ], "exclude": ["target/**/*"] } From b593079c6a0e02efb221f61d8ed26ff278e7c91d Mon Sep 17 00:00:00 2001 From: Dario Gieselaar Date: Mon, 27 Nov 2023 14:41:35 +0100 Subject: [PATCH 04/15] Fix failing test --- .../public/components/chat/chat_body.tsx | 27 ++++++ .../public/hooks/use_conversation.test.tsx | 96 ++++++++++--------- .../public/hooks/use_conversation.ts | 8 +- 3 files changed, 82 insertions(+), 49 deletions(-) diff --git a/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_body.tsx b/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_body.tsx index 88990b3a9e47ba..838ca53219c360 100644 --- a/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_body.tsx +++ b/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_body.tsx @@ -259,6 +259,33 @@ export function ChatBody({ ); } + if (conversation.error) { + return ( + + + + {i18n.translate('xpack.observabilityAiAssistant.couldNotFindConversationContent', { + defaultMessage: + 'Could not find a conversation with id {conversationId}. Make sure the conversation exists and you have access to it.', + values: { conversationId: initialConversationId }, + })} + + + + ); + } + return ( {connectors.selectedConnector ? 
( diff --git a/x-pack/plugins/observability_ai_assistant/public/hooks/use_conversation.test.tsx b/x-pack/plugins/observability_ai_assistant/public/hooks/use_conversation.test.tsx index b9e837962ee1ba..1ea3ab21b0f6ad 100644 --- a/x-pack/plugins/observability_ai_assistant/public/hooks/use_conversation.test.tsx +++ b/x-pack/plugins/observability_ai_assistant/public/hooks/use_conversation.test.tsx @@ -25,7 +25,7 @@ import { ChatState } from './use_chat'; import { createMockChatService } from '../service/create_mock_chat_service'; import { Subject } from 'rxjs'; import { EMPTY_CONVERSATION_TITLE } from '../i18n'; -import { omit } from 'lodash'; +import { merge, omit } from 'lodash'; let hookResult: RenderHookResult; @@ -271,13 +271,55 @@ describe('useConversation', () => { describe('when chat completes without an initial conversation id', () => { const subject: Subject = new Subject(); - beforeEach(() => { - mockService.callApi.mockImplementation(async (endpoint, request) => ({ - conversation: { - id: 'my-conversation-id', + const expectedMessages = [ + { + '@timestamp': expect.any(String), + message: { + role: MessageRole.System, + content: '', }, - messages: (request as any).params.body.messages, - })); + }, + { + '@timestamp': expect.any(String), + message: { + role: MessageRole.User, + content: 'Hello', + }, + }, + { + '@timestamp': expect.any(String), + message: { + role: MessageRole.Assistant, + content: 'Goodbye', + }, + }, + { + '@timestamp': expect.any(String), + message: { + role: MessageRole.User, + content: 'Hello again', + }, + }, + { + '@timestamp': expect.any(String), + message: { + role: MessageRole.Assistant, + content: 'Goodbye again', + }, + }, + ]; + beforeEach(() => { + mockService.callApi.mockImplementation(async (endpoint, request) => + merge( + { + conversation: { + id: 'my-conversation-id', + }, + messages: expectedMessages, + }, + (request as any).params.body + ) + ); hookResult = renderHook(useConversation, { initialProps: { @@ -309,44 +351,6 
@@ describe('useConversation', () => { }); it('the conversation is created including the initial messages', async () => { - const expectedMessages = [ - { - '@timestamp': expect.any(String), - message: { - role: MessageRole.System, - content: '', - }, - }, - { - '@timestamp': expect.any(String), - message: { - role: MessageRole.User, - content: 'Hello', - }, - }, - { - '@timestamp': expect.any(String), - message: { - role: MessageRole.Assistant, - content: 'Goodbye', - }, - }, - { - '@timestamp': expect.any(String), - message: { - role: MessageRole.User, - content: 'Hello again', - }, - }, - { - '@timestamp': expect.any(String), - message: { - role: MessageRole.Assistant, - content: 'Goodbye again', - }, - }, - ]; - act(() => { hookResult.result.current.next( hookResult.result.current.messages.concat({ @@ -389,6 +393,8 @@ describe('useConversation', () => { }, ]); + expect(hookResult.result.current.conversation.error).toBeUndefined(); + expect(hookResult.result.current.messages).toEqual(expectedMessages); }); }); diff --git a/x-pack/plugins/observability_ai_assistant/public/hooks/use_conversation.ts b/x-pack/plugins/observability_ai_assistant/public/hooks/use_conversation.ts index 8f8581874c27ba..133a8fd40d30b1 100644 --- a/x-pack/plugins/observability_ai_assistant/public/hooks/use_conversation.ts +++ b/x-pack/plugins/observability_ai_assistant/public/hooks/use_conversation.ts @@ -109,10 +109,10 @@ export function useConversation({ }); return ( - initialConversationId + displayedConversationId ? 
update( merge( - { conversation: { id: initialConversationId } }, + { conversation: { id: displayedConversationId } }, nextConversationObject ) as Conversation ) @@ -197,10 +197,10 @@ export function useConversation({ throw error; }); }, - [initialConversationId, initialTitle], + [displayedConversationId, initialTitle], { defaultValue: () => { - if (!initialConversationId) { + if (!displayedConversationId) { const nextConversation = createNewConversation({ title: initialTitle }); return nextConversation; } From 3d39b809071d27ee94fff5a1e479a70eecc17e4a Mon Sep 17 00:00:00 2001 From: Dario Gieselaar Date: Tue, 28 Nov 2023 08:03:02 +0100 Subject: [PATCH 05/15] Remove unused files --- .../common/chat/streaming.ts | 74 ---------- .../common/errors/index.ts | 23 --- .../common/index.ts | 4 - .../public/hooks/use_conversation.test.tsx | 135 +++++++++++++++++- .../public/hooks/use_conversation.ts | 16 ++- 5 files changed, 149 insertions(+), 103 deletions(-) delete mode 100644 x-pack/plugins/observability_ai_assistant/common/chat/streaming.ts delete mode 100644 x-pack/plugins/observability_ai_assistant/common/errors/index.ts diff --git a/x-pack/plugins/observability_ai_assistant/common/chat/streaming.ts b/x-pack/plugins/observability_ai_assistant/common/chat/streaming.ts deleted file mode 100644 index 6ee1751064f09c..00000000000000 --- a/x-pack/plugins/observability_ai_assistant/common/chat/streaming.ts +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -import { ChatCompletionError as ChatCompletionErrorClass } from '../errors'; -import { Message } from '../types'; - -export enum StreamingChatResponseEventType { - ChatCompletionChunk = 'chatCompletionChunk', - ConversationCreate = 'conversationCreate', - ConversationUpdate = 'conversationUpdate', - MessageAdd = 'messageAdd', - ChatCompletionError = 'chatCompletionError', -} - -type StreamingChatResponseEventBase< - TEventType extends StreamingChatResponseEventType, - TData extends {} -> = { - type: TEventType; -} & TData; - -type ChatCompletionChunkEvent = StreamingChatResponseEventBase< - StreamingChatResponseEventType.ChatCompletionChunk, - { - message: { - content?: string; - function_call?: { - name?: string; - args?: string; - }; - }; - } ->; - -type ConversationCreateEvent = StreamingChatResponseEventBase< - StreamingChatResponseEventType.ConversationCreate, - { - conversation: { - id: string; - }; - } ->; - -type ConversationUpdateEvent = StreamingChatResponseEventBase< - StreamingChatResponseEventType.ConversationUpdate, - { - conversation: { - id: string; - title: string; - last_updated: string; - }; - } ->; - -type MessageAddEvent = StreamingChatResponseEventBase< - StreamingChatResponseEventType.MessageAdd, - Message ->; - -type ChatCompletionErrorEvent = StreamingChatResponseEventBase< - StreamingChatResponseEventType.ChatCompletionError, - typeof ChatCompletionErrorClass ->; - -export type StreamingChatResponseEvent = - | ChatCompletionChunkEvent - | ConversationCreateEvent - | ConversationUpdateEvent - | MessageAddEvent - | ChatCompletionErrorEvent; diff --git a/x-pack/plugins/observability_ai_assistant/common/errors/index.ts b/x-pack/plugins/observability_ai_assistant/common/errors/index.ts deleted file mode 100644 index 6625b0c9cf9a12..00000000000000 --- a/x-pack/plugins/observability_ai_assistant/common/errors/index.ts +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -export enum ChatCompletionErrorCode { - InternalError = 'internalError', -} - -export class ChatCompletionError extends Error { - code: ChatCompletionErrorCode; - - constructor(code: ChatCompletionErrorCode, message: string) { - super(message); - this.code = code; - } -} - -export function isChatCompletionError(error: Error): error is ChatCompletionError { - return error instanceof ChatCompletionError; -} diff --git a/x-pack/plugins/observability_ai_assistant/common/index.ts b/x-pack/plugins/observability_ai_assistant/common/index.ts index a4f181a127cfc8..92cd91871da696 100644 --- a/x-pack/plugins/observability_ai_assistant/common/index.ts +++ b/x-pack/plugins/observability_ai_assistant/common/index.ts @@ -7,7 +7,3 @@ export type { Message, Conversation } from './types'; export { MessageRole } from './types'; - -export { type StreamingChatResponseEvent, StreamingChatResponseEventType } from './chat/streaming'; - -export { ChatCompletionError, ChatCompletionErrorCode } from './errors'; diff --git a/x-pack/plugins/observability_ai_assistant/public/hooks/use_conversation.test.tsx b/x-pack/plugins/observability_ai_assistant/public/hooks/use_conversation.test.tsx index 1ea3ab21b0f6ad..11d5f322ddbcfa 100644 --- a/x-pack/plugins/observability_ai_assistant/public/hooks/use_conversation.test.tsx +++ b/x-pack/plugins/observability_ai_assistant/public/hooks/use_conversation.test.tsx @@ -537,6 +537,139 @@ describe('useConversation', () => { }); describe('when the title is updated', () => { - it('the conversation is saved with the updated title', () => {}); + describe('without a stored conversation', () => { + beforeEach(() => { + hookResult = renderHook(useConversation, { + initialProps: { + chatService: mockChatService, + connectorId: 'my-connector', + initialMessages: [ + { + '@timestamp': 
new Date().toISOString(), + message: { content: '', role: MessageRole.User }, + }, + ], + initialConversationId: 'foo', + }, + wrapper, + }); + }); + + it('throws an error', () => { + expect(() => hookResult.result.current.saveTitle('my-new-title')).toThrow(); + }); + }); + + describe('with a stored conversation', () => { + let resolve: (value: unknown) => void; + beforeEach(async () => { + mockService.callApi.mockImplementation(async (endpoint, request) => { + if ( + endpoint === 'PUT /internal/observability_ai_assistant/conversation/{conversationId}' + ) { + return new Promise((_resolve) => { + resolve = _resolve; + }); + } + return { + '@timestamp': new Date().toISOString(), + conversation: { + id: 'my-conversation-id', + title: EMPTY_CONVERSATION_TITLE, + }, + labels: {}, + numeric_labels: {}, + public: false, + messages: [], + }; + }); + + await act(async () => { + hookResult = renderHook(useConversation, { + initialProps: { + chatService: mockChatService, + connectorId: 'my-connector', + initialConversationId: 'my-conversation-id', + }, + wrapper, + }); + }); + }); + + it('does not throw an error', () => { + expect(() => hookResult.result.current.saveTitle('my-new-title')).not.toThrow(); + }); + + it('calls the update API', async () => { + act(() => { + hookResult.result.current.saveTitle('my-new-title'); + }); + + expect(resolve).not.toBeUndefined(); + + expect(mockService.callApi.mock.calls[1]).toEqual([ + 'PUT /internal/observability_ai_assistant/conversation/{conversationId}', + { + signal: null, + params: { + path: { + conversationId: 'my-conversation-id', + }, + body: { + conversation: { + '@timestamp': expect.any(String), + conversation: { + title: 'my-new-title', + id: 'my-conversation-id', + }, + labels: expect.anything(), + messages: expect.anything(), + numeric_labels: expect.anything(), + public: expect.anything(), + }, + }, + }, + }, + ]); + + mockService.callApi.mockImplementation(async (endpoint, request) => { + return { + '@timestamp': new 
Date().toISOString(), + conversation: { + id: 'my-conversation-id', + title: 'my-new-title', + }, + labels: {}, + numeric_labels: {}, + public: false, + messages: [], + }; + }); + + await act(async () => { + resolve({ + conversation: { + title: 'my-new-title', + }, + }); + }); + + expect(mockService.callApi.mock.calls[2]).toEqual([ + 'GET /internal/observability_ai_assistant/conversation/{conversationId}', + { + signal: expect.anything(), + params: { + path: { + conversationId: 'my-conversation-id', + }, + }, + }, + ]); + + expect(hookResult.result.current.conversation.value?.conversation.title).toBe( + 'my-new-title' + ); + }); + }); }); }); diff --git a/x-pack/plugins/observability_ai_assistant/public/hooks/use_conversation.ts b/x-pack/plugins/observability_ai_assistant/public/hooks/use_conversation.ts index 133a8fd40d30b1..c753f7c7b19292 100644 --- a/x-pack/plugins/observability_ai_assistant/public/hooks/use_conversation.ts +++ b/x-pack/plugins/observability_ai_assistant/public/hooks/use_conversation.ts @@ -215,6 +215,20 @@ export function useConversation({ next, stop, messages, - saveTitle: () => {}, + saveTitle: (title: string) => { + if (!displayedConversationId || !conversation.value) { + throw new Error('Cannot save title if conversation is not stored'); + } + const nextConversation = merge({}, conversation.value as Conversation, { + conversation: { title }, + }); + return update(nextConversation) + .then(() => { + return conversation.refresh(); + }) + .then(() => { + onConversationUpdate?.(nextConversation); + }); + }, }; } From 1e1e7c08452f62d0ad2862e18e3d2199d97b0f4b Mon Sep 17 00:00:00 2001 From: Dario Gieselaar Date: Tue, 28 Nov 2023 12:25:46 +0100 Subject: [PATCH 06/15] Disable editing of title until conversation has been stored --- .../public/components/chat/chat_body.tsx | 12 ++++++++++-- .../public/components/chat/chat_header.tsx | 7 ++++++- .../public/components/chat/chat_timeline.tsx | 2 +- 3 files changed, 17 insertions(+), 4 deletions(-) diff 
--git a/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_body.tsx b/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_body.tsx index 838ca53219c360..afbb042d2eadcc 100644 --- a/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_body.tsx +++ b/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_body.tsx @@ -171,7 +171,11 @@ export function ChatBody({ ); - } else if (connectors.loading || knowledgeBase.status.loading || conversation.loading) { + } else if ( + connectors.loading || + knowledgeBase.status.loading || + (!conversation.value && conversation.loading) + ) { footer = ( @@ -316,7 +320,11 @@ export function ChatBody({ ) : null} diff --git a/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_timeline.tsx b/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_timeline.tsx index 065208695349d1..3c1f8935e00ea7 100644 --- a/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_timeline.tsx +++ b/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_timeline.tsx @@ -5,7 +5,7 @@ * 2.0. 
*/ -import React, { ReactNode, useMemo } from 'react'; +import React, { ReactNode, useEffect, useMemo } from 'react'; import { css } from '@emotion/css'; import { EuiCommentList } from '@elastic/eui'; import type { AuthenticatedUser } from '@kbn/security-plugin/common'; From bfd2ef9418296825b92de87c302a3e49039836df Mon Sep 17 00:00:00 2001 From: Dario Gieselaar Date: Tue, 28 Nov 2023 12:32:54 +0100 Subject: [PATCH 07/15] Remove unused import --- .../public/components/chat/chat_timeline.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_timeline.tsx b/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_timeline.tsx index 3c1f8935e00ea7..065208695349d1 100644 --- a/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_timeline.tsx +++ b/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_timeline.tsx @@ -5,7 +5,7 @@ * 2.0. */ -import React, { ReactNode, useEffect, useMemo } from 'react'; +import React, { ReactNode, useMemo } from 'react'; import { css } from '@emotion/css'; import { EuiCommentList } from '@elastic/eui'; import type { AuthenticatedUser } from '@kbn/security-plugin/common'; From 02d8f6a71341cb4ceaa55f23520371897fdadf12 Mon Sep 17 00:00:00 2001 From: Dario Gieselaar Date: Tue, 5 Dec 2023 16:01:37 +0100 Subject: [PATCH 08/15] [Obs AI Assistant] /complete endpoint --- .../get_apm_services_list.ts | 77 -- .../get_apm_timeseries.tsx | 416 +++--- .../apm/public/assistant_functions/index.ts | 145 +-- x-pack/plugins/apm/public/plugin.ts | 8 +- .../get_apm_correlations.ts | 21 +- .../get_apm_downstream_dependencies.ts | 24 +- .../get_apm_error_document.ts | 21 +- .../get_apm_service_summary.ts | 46 +- .../get_apm_services_list.ts | 137 ++ .../assistant_functions/get_apm_timeseries.ts | 174 +++ .../apm/server/assistant_functions/index.ts | 186 +++ x-pack/plugins/apm/server/plugin.ts | 17 +- 
.../routes/assistant_functions/route.ts | 130 -- x-pack/plugins/apm/server/types.ts | 8 +- .../utils/non_empty_string_ref.ts | 0 .../common/conversation_complete.ts | 109 ++ .../common/functions/lens.tsx | 122 ++ .../common/types.ts | 53 +- .../common/utils/concatenate_openai_chunks.ts | 34 + .../utils/filter_function_definitions.ts | 29 + .../common/utils/process_openai_stream.ts | 45 + .../observability_ai_assistant/jest.config.js | 3 + .../observability_ai_assistant/kibana.jsonc | 3 +- .../public/components/chat/chat_body.tsx | 2 +- .../components/chat/function_list_popover.tsx | 8 +- .../public/components/insight/insight.tsx | 1 + .../public/functions/alerts.ts | 86 -- .../public/functions/get_dataset_info.ts | 153 --- .../public/functions/index.ts | 85 +- .../public/functions/kibana.ts | 69 - .../public/functions/lens.tsx | 123 +- .../public/hooks/use_chat.test.ts | 309 +---- .../public/hooks/use_chat.ts | 296 +++-- .../public/hooks/use_conversation.test.tsx | 254 +--- .../public/hooks/use_conversation.ts | 72 +- .../public/hooks/use_json_editor_model.ts | 10 +- .../public/plugin.tsx | 7 +- .../service/create_chat_service.test.ts | 6 + .../public/service/create_chat_service.ts | 287 ++--- .../service/create_mock_chat_service.ts | 2 +- .../public/service/create_service.ts | 7 +- .../public/types.ts | 53 +- .../public/utils/builders.ts | 54 +- .../public/utils/create_initialized_object.ts | 2 +- .../public/utils/storybook_decorator.tsx | 30 +- .../server/functions/alerts.ts | 153 +++ .../functions/elasticsearch.ts | 34 +- .../{public => server}/functions/esql.ts | 156 +-- .../server/functions/get_dataset_info.ts | 191 +++ .../server/functions/index.ts | 86 ++ .../server/functions/lens.ts | 16 + .../{public => server}/functions/recall.ts | 25 +- .../{public => server}/functions/summarize.ts | 32 +- .../server/index.ts | 5 + .../server/plugin.ts | 7 +- .../server/routes/chat/route.ts | 63 + .../server/routes/conversations/route.ts | 32 - 
.../server/routes/functions/route.ts | 145 +-- .../server/routes/types.ts | 4 + .../service/chat_function_client/index.ts | 91 ++ .../service/client/handle_llm_response.ts | 71 ++ .../server/service/client/index.test.ts | 1127 +++++++++++++++++ .../server/service/client/index.ts | 312 ++++- .../server/service/index.ts | 117 +- .../server/service/types.ts | 48 + .../service/util/stream_into_observable.ts | 23 + .../server/types.ts | 5 +- 67 files changed, 4013 insertions(+), 2454 deletions(-) delete mode 100644 x-pack/plugins/apm/public/assistant_functions/get_apm_services_list.ts rename x-pack/plugins/apm/{public => server}/assistant_functions/get_apm_correlations.ts (91%) rename x-pack/plugins/apm/{public => server}/assistant_functions/get_apm_downstream_dependencies.ts (83%) rename x-pack/plugins/apm/{public => server}/assistant_functions/get_apm_error_document.ts (83%) rename x-pack/plugins/apm/{public => server}/assistant_functions/get_apm_service_summary.ts (64%) create mode 100644 x-pack/plugins/apm/server/assistant_functions/get_apm_services_list.ts create mode 100644 x-pack/plugins/apm/server/assistant_functions/get_apm_timeseries.ts create mode 100644 x-pack/plugins/apm/server/assistant_functions/index.ts rename x-pack/plugins/apm/{public => server}/utils/non_empty_string_ref.ts (100%) create mode 100644 x-pack/plugins/observability_ai_assistant/common/conversation_complete.ts create mode 100644 x-pack/plugins/observability_ai_assistant/common/functions/lens.tsx create mode 100644 x-pack/plugins/observability_ai_assistant/common/utils/concatenate_openai_chunks.ts create mode 100644 x-pack/plugins/observability_ai_assistant/common/utils/filter_function_definitions.ts create mode 100644 x-pack/plugins/observability_ai_assistant/common/utils/process_openai_stream.ts delete mode 100644 x-pack/plugins/observability_ai_assistant/public/functions/alerts.ts delete mode 100644 x-pack/plugins/observability_ai_assistant/public/functions/get_dataset_info.ts delete 
mode 100644 x-pack/plugins/observability_ai_assistant/public/functions/kibana.ts create mode 100644 x-pack/plugins/observability_ai_assistant/server/functions/alerts.ts rename x-pack/plugins/observability_ai_assistant/{public => server}/functions/elasticsearch.ts (65%) rename x-pack/plugins/observability_ai_assistant/{public => server}/functions/esql.ts (89%) create mode 100644 x-pack/plugins/observability_ai_assistant/server/functions/get_dataset_info.ts create mode 100644 x-pack/plugins/observability_ai_assistant/server/functions/index.ts create mode 100644 x-pack/plugins/observability_ai_assistant/server/functions/lens.ts rename x-pack/plugins/observability_ai_assistant/{public => server}/functions/recall.ts (86%) rename x-pack/plugins/observability_ai_assistant/{public => server}/functions/summarize.ts (82%) create mode 100644 x-pack/plugins/observability_ai_assistant/server/service/chat_function_client/index.ts create mode 100644 x-pack/plugins/observability_ai_assistant/server/service/client/handle_llm_response.ts create mode 100644 x-pack/plugins/observability_ai_assistant/server/service/client/index.test.ts create mode 100644 x-pack/plugins/observability_ai_assistant/server/service/util/stream_into_observable.ts diff --git a/x-pack/plugins/apm/public/assistant_functions/get_apm_services_list.ts b/x-pack/plugins/apm/public/assistant_functions/get_apm_services_list.ts deleted file mode 100644 index 9047f768449e52..00000000000000 --- a/x-pack/plugins/apm/public/assistant_functions/get_apm_services_list.ts +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -import { i18n } from '@kbn/i18n'; -import type { RegisterFunctionDefinition } from '@kbn/observability-ai-assistant-plugin/common/types'; -import { ServiceHealthStatus } from '../../common/service_health_status'; -import { callApmApi } from '../services/rest/create_call_apm_api'; -import { NON_EMPTY_STRING } from '../utils/non_empty_string_ref'; - -export function registerGetApmServicesListFunction({ - registerFunction, -}: { - registerFunction: RegisterFunctionDefinition; -}) { - registerFunction( - { - name: 'get_apm_services_list', - contexts: ['apm'], - description: `Gets a list of services`, - descriptionForUser: i18n.translate( - 'xpack.apm.observabilityAiAssistant.functions.registerGetApmServicesList.descriptionForUser', - { - defaultMessage: `Gets the list of monitored services, their health status, and alerts.`, - } - ), - parameters: { - type: 'object', - additionalProperties: false, - properties: { - 'service.environment': { - ...NON_EMPTY_STRING, - description: - 'Optionally filter the services by the environments that they are running in', - }, - start: { - ...NON_EMPTY_STRING, - description: - 'The start of the time range, in Elasticsearch date math, like `now`.', - }, - end: { - ...NON_EMPTY_STRING, - description: - 'The end of the time range, in Elasticsearch date math, like `now-24h`.', - }, - healthStatus: { - type: 'array', - description: 'Filter service list by health status', - additionalProperties: false, - additionalItems: false, - items: { - type: 'string', - enum: [ - ServiceHealthStatus.unknown, - ServiceHealthStatus.healthy, - ServiceHealthStatus.warning, - ServiceHealthStatus.critical, - ], - }, - }, - }, - required: ['start', 'end'], - } as const, - }, - async ({ arguments: args }, signal) => { - return callApmApi('POST /internal/apm/assistant/get_services_list', { - signal, - params: { - body: args, - }, - }); - } - ); -} diff --git a/x-pack/plugins/apm/public/assistant_functions/get_apm_timeseries.tsx 
b/x-pack/plugins/apm/public/assistant_functions/get_apm_timeseries.tsx index 56269d884ad84c..445610321cbb8c 100644 --- a/x-pack/plugins/apm/public/assistant_functions/get_apm_timeseries.tsx +++ b/x-pack/plugins/apm/public/assistant_functions/get_apm_timeseries.tsx @@ -5,294 +5,164 @@ * 2.0. */ import { EuiFlexGroup, EuiFlexItem, EuiText } from '@elastic/eui'; -import type { RegisterFunctionDefinition } from '@kbn/observability-ai-assistant-plugin/common/types'; +import { useKibana } from '@kbn/kibana-react-plugin/public'; +import type { + RegisterRenderFunctionDefinition, + RenderFunction, +} from '@kbn/observability-ai-assistant-plugin/public/types'; + import { groupBy } from 'lodash'; import React from 'react'; -import { i18n } from '@kbn/i18n'; -import { useKibana } from '@kbn/kibana-react-plugin/public'; -import { FETCH_STATUS } from '../hooks/use_fetcher'; -import { callApmApi } from '../services/rest/create_call_apm_api'; -import { getTimeZone } from '../components/shared/charts/helper/timezone'; -import { TimeseriesChart } from '../components/shared/charts/timeseries_chart'; -import { ChartPointerEventContextProvider } from '../context/chart_pointer_event/chart_pointer_event_context'; -import { ApmThemeProvider } from '../components/routing/app_root'; -import { Coordinate, TimeSeries } from '../../typings/timeseries'; -import { - ChartType, - getTimeSeriesColor, -} from '../components/shared/charts/helper/get_timeseries_color'; import { LatencyAggregationType } from '../../common/latency_aggregation_types'; import { asPercent, asTransactionRate, getDurationFormatter, } from '../../common/utils/formatters'; +import type { + GetApmTimeseriesFunctionArguments, + GetApmTimeseriesFunctionResponse, +} from '../../server/assistant_functions/get_apm_timeseries'; +import { Coordinate, TimeSeries } from '../../typings/timeseries'; +import { ApmThemeProvider } from '../components/routing/app_root'; +import { + ChartType, + getTimeSeriesColor, +} from 
'../components/shared/charts/helper/get_timeseries_color'; +import { getTimeZone } from '../components/shared/charts/helper/timezone'; +import { TimeseriesChart } from '../components/shared/charts/timeseries_chart'; import { getMaxY, getResponseTimeTickFormatter, } from '../components/shared/charts/transaction_charts/helper'; -import { NON_EMPTY_STRING } from '../utils/non_empty_string_ref'; +import { ChartPointerEventContextProvider } from '../context/chart_pointer_event/chart_pointer_event_context'; +import { FETCH_STATUS } from '../hooks/use_fetcher'; export function registerGetApmTimeseriesFunction({ - registerFunction, + registerRenderFunction, }: { - registerFunction: RegisterFunctionDefinition; + registerRenderFunction: RegisterRenderFunctionDefinition; }) { - registerFunction( - { - contexts: ['apm'], - name: 'get_apm_timeseries', - descriptionForUser: i18n.translate( - 'xpack.apm.observabilityAiAssistant.functions.registerGetApmTimeseries.descriptionForUser', - { - defaultMessage: `Display different APM metrics, like throughput, failure rate, or latency, for any service or all services, or any or all of its dependencies, both as a timeseries and as a single statistic. Additionally, the function will return any changes, such as spikes, step and trend changes, or dips. You can also use it to compare data by requesting two different time ranges, or for instance two different service versions`, - } - ), - description: `Visualise and analyse different APM metrics, like throughput, failure rate, or latency, for any service or all services, or any or all of its dependencies, both as a timeseries and as a single statistic. A visualisation will be displayed above your reply - DO NOT attempt to display or generate an image yourself, or any other placeholder. Additionally, the function will return any changes, such as spikes, step and trend changes, or dips. 
You can also use it to compare data by requesting two different time ranges, or for instance two different service versions.`, - parameters: { - type: 'object', - properties: { - start: { - type: 'string', - description: - 'The start of the time range, in Elasticsearch date math, like `now`.', - }, - end: { - type: 'string', - description: - 'The end of the time range, in Elasticsearch date math, like `now-24h`.', - }, - stats: { - type: 'array', - items: { - type: 'object', - properties: { - timeseries: { - description: 'The metric to be displayed', - oneOf: [ - { - type: 'object', - properties: { - name: { - type: 'string', - enum: [ - 'transaction_throughput', - 'transaction_failure_rate', - ], - }, - 'transaction.type': { - type: 'string', - description: 'The transaction type', - }, - }, - required: ['name'], - }, - { - type: 'object', - properties: { - name: { - type: 'string', - enum: [ - 'exit_span_throughput', - 'exit_span_failure_rate', - 'exit_span_latency', - ], - }, - 'span.destination.service.resource': { - type: 'string', - description: - 'The name of the downstream dependency for the service', - }, - }, - required: ['name'], - }, - { - type: 'object', - properties: { - name: { - type: 'string', - const: 'error_event_rate', - }, - }, - required: ['name'], - }, - { - type: 'object', - properties: { - name: { - type: 'string', - const: 'transaction_latency', - }, - 'transaction.type': { - type: 'string', - }, - function: { - type: 'string', - enum: ['avg', 'p95', 'p99'], - }, - }, - required: ['name', 'function'], - }, - ], - }, - 'service.name': { - ...NON_EMPTY_STRING, - description: 'The name of the service', - }, - 'service.environment': { - description: - 'The environment that the service is running in. If undefined, all environments will be included. Only use this if you have confirmed the environment that the service is running in.', - }, - filter: { - type: 'string', - description: - 'a KQL query to filter the data by. 
If no filter should be applied, leave it empty.', - }, - title: { - type: 'string', - description: - 'A unique, human readable, concise title for this specific group series.', - }, - offset: { - type: 'string', - description: - 'The offset. Right: 15m. 8h. 1d. Wrong: -15m. -8h. -1d.', - }, - }, - required: ['service.name', 'timeseries', 'title'], - }, - }, - }, - required: ['stats', 'start', 'end'], - } as const, - }, - async ({ arguments: { stats, start, end } }, signal) => { - const response = await callApmApi( - 'POST /internal/apm/assistant/get_apm_timeseries', - { - signal, - params: { - body: { stats: stats as any, start, end }, - }, - } - ); - - return response; - }, - ({ arguments: args, response }) => { - const groupedSeries = groupBy(response.data, (series) => series.group); - - const { - services: { uiSettings }, - } = useKibana(); - - const timeZone = getTimeZone(uiSettings); - - return ( - - - - {Object.values(groupedSeries).map((groupSeries) => { - const groupId = groupSeries[0].group; - - const maxY = getMaxY(groupSeries); - const latencyFormatter = getDurationFormatter(maxY, 10, 1000); - - let yLabelFormat: (value: number) => string; - - const firstStat = groupSeries[0].stat; - - switch (firstStat.timeseries.name) { - case 'transaction_throughput': - case 'exit_span_throughput': - case 'error_event_rate': - yLabelFormat = asTransactionRate; - break; - - case 'transaction_latency': - case 'exit_span_latency': - yLabelFormat = - getResponseTimeTickFormatter(latencyFormatter); - break; - - case 'transaction_failure_rate': - case 'exit_span_failure_rate': - yLabelFormat = (y) => asPercent(y || 0, 100); - break; - } - - const timeseries: Array> = - groupSeries.map((series): TimeSeries => { - let chartType: ChartType; - - const data = series.data; - - switch (series.stat.timeseries.name) { - case 'transaction_throughput': - case 'exit_span_throughput': - chartType = ChartType.THROUGHPUT; - break; - - case 'transaction_failure_rate': - case 
'exit_span_failure_rate': - chartType = ChartType.FAILED_TRANSACTION_RATE; - break; - - case 'transaction_latency': - if ( - series.stat.timeseries.function === - LatencyAggregationType.p99 - ) { - chartType = ChartType.LATENCY_P99; - } else if ( - series.stat.timeseries.function === - LatencyAggregationType.p95 - ) { - chartType = ChartType.LATENCY_P95; - } else { - chartType = ChartType.LATENCY_AVG; - } - break; - - case 'exit_span_latency': + registerRenderFunction('get_apm_timeseries', (parameters) => { + const { response } = parameters as Parameters< + RenderFunction< + GetApmTimeseriesFunctionArguments, + GetApmTimeseriesFunctionResponse + > + >[0]; + + const groupedSeries = groupBy(response.data, (series) => series.group); + + const { + services: { uiSettings }, + } = useKibana(); + + const timeZone = getTimeZone(uiSettings); + + return ( + + + + {Object.values(groupedSeries).map((groupSeries) => { + const groupId = groupSeries[0].group; + + const maxY = getMaxY(groupSeries); + const latencyFormatter = getDurationFormatter(maxY, 10, 1000); + + let yLabelFormat: (value: number) => string; + + const firstStat = groupSeries[0].stat; + + switch (firstStat.timeseries.name) { + case 'transaction_throughput': + case 'exit_span_throughput': + case 'error_event_rate': + yLabelFormat = asTransactionRate; + break; + + case 'transaction_latency': + case 'exit_span_latency': + yLabelFormat = getResponseTimeTickFormatter(latencyFormatter); + break; + + case 'transaction_failure_rate': + case 'exit_span_failure_rate': + yLabelFormat = (y) => asPercent(y || 0, 100); + break; + } + + const timeseries: Array> = groupSeries.map( + (series): TimeSeries => { + let chartType: ChartType; + + const data = series.data; + + switch (series.stat.timeseries.name) { + case 'transaction_throughput': + case 'exit_span_throughput': + chartType = ChartType.THROUGHPUT; + break; + + case 'transaction_failure_rate': + case 'exit_span_failure_rate': + chartType = 
ChartType.FAILED_TRANSACTION_RATE; + break; + + case 'transaction_latency': + if ( + series.stat.timeseries.function === + LatencyAggregationType.p99 + ) { + chartType = ChartType.LATENCY_P99; + } else if ( + series.stat.timeseries.function === + LatencyAggregationType.p95 + ) { + chartType = ChartType.LATENCY_P95; + } else { chartType = ChartType.LATENCY_AVG; - break; - - case 'error_event_rate': - chartType = ChartType.ERROR_OCCURRENCES; - break; - } - - return { - title: series.id, - type: 'line', - color: getTimeSeriesColor(chartType!).currentPeriodColor, - data, - }; - }); - - return ( - - - - {groupId} - - - - - ); - })} - - - - ); - } - ); + } + break; + + case 'exit_span_latency': + chartType = ChartType.LATENCY_AVG; + break; + + case 'error_event_rate': + chartType = ChartType.ERROR_OCCURRENCES; + break; + } + + return { + title: series.id, + type: 'line', + color: getTimeSeriesColor(chartType!).currentPeriodColor, + data, + }; + } + ); + + return ( + + + + {groupId} + + + + + ); + })} + + + + ); + }); } diff --git a/x-pack/plugins/apm/public/assistant_functions/index.ts b/x-pack/plugins/apm/public/assistant_functions/index.ts index 128091cf2472de..773d2fbfba27fd 100644 --- a/x-pack/plugins/apm/public/assistant_functions/index.ts +++ b/x-pack/plugins/apm/public/assistant_functions/index.ts @@ -5,152 +5,15 @@ * 2.0. 
*/ -import type { CoreStart } from '@kbn/core/public'; -import type { - RegisterContextDefinition, - RegisterFunctionDefinition, -} from '@kbn/observability-ai-assistant-plugin/common/types'; -import type { ApmPluginStartDeps } from '../plugin'; -import { - createCallApmApi, - callApmApi, -} from '../services/rest/create_call_apm_api'; -import { registerGetApmCorrelationsFunction } from './get_apm_correlations'; -import { registerGetApmDownstreamDependenciesFunction } from './get_apm_downstream_dependencies'; -import { registerGetApmErrorDocumentFunction } from './get_apm_error_document'; -import { registerGetApmServicesListFunction } from './get_apm_services_list'; -import { registerGetApmServiceSummaryFunction } from './get_apm_service_summary'; +import { RegisterRenderFunctionDefinition } from '@kbn/observability-ai-assistant-plugin/public/types'; import { registerGetApmTimeseriesFunction } from './get_apm_timeseries'; export async function registerAssistantFunctions({ - pluginsStart, - coreStart, - registerContext, - registerFunction, - signal, + registerRenderFunction, }: { - pluginsStart: ApmPluginStartDeps; - coreStart: CoreStart; - registerContext: RegisterContextDefinition; - registerFunction: RegisterFunctionDefinition; - signal: AbortSignal; + registerRenderFunction: RegisterRenderFunctionDefinition; }) { - createCallApmApi(coreStart); - - const response = await callApmApi('GET /internal/apm/has_data', { - signal, - }); - - if (!response.hasData) { - return; - } - registerGetApmTimeseriesFunction({ - registerFunction, - }); - - registerGetApmErrorDocumentFunction({ - registerFunction, - }); - - registerGetApmCorrelationsFunction({ - registerFunction, - }); - - registerGetApmDownstreamDependenciesFunction({ - registerFunction, - }); - - registerGetApmServiceSummaryFunction({ - registerFunction, - }); - - registerGetApmServicesListFunction({ - registerFunction, - }); - - registerContext({ - name: 'apm', - description: ` - When analyzing APM data, prefer 
the APM specific functions over the generic Lens, - Elasticsearch or Kibana ones, unless those are explicitly requested by the user. - - When requesting metrics for a service, make sure you also know what environment - it is running in. Metrics aggregated over multiple environments are useless. - - There are four important data types in Elastic APM. Each of them have the - following fields: - - service.name: the name of the service - - service.node.name: the id of the service instance (often the hostname) - - service.environment: the environment (often production, development) - - agent.name: the name of the agent (go, java, etc) - - The four data types are transactions, exit spans, error events, and application - metrics. - - Transactions have three metrics: throughput, failure rate, and latency. The - fields are: - - - transaction.type: often request or page-load (the main transaction types), - but can also be worker, or route-change. - - transaction.name: The name of the transaction group, often something like - 'GET /api/product/:productId' - - transaction.result: The result. Used to capture HTTP response codes - (2xx,3xx,4xx,5xx) for request transactions. - - event.outcome: whether the transaction was succesful or not. success, - failure, or unknown. - - Exit spans have three metrics: throughput, failure rate and latency. The fields - are: - - span.type: db, external - - span.subtype: the type of database (redis, postgres) or protocol (http, grpc) - - span.destination.service.resource: the address of the destination of the call - - event.outcome: whether the transaction was succesful or not. success, - failure, or unknown. - - Error events have one metric, error event rate. The fields are: - - error.grouping_name: a human readable keyword that identifies the error group - - For transaction metrics we also collect anomalies. These are scored 0 (low) to - 100 (critical). 
- - For root cause analysis, locate a change point in the relevant metrics for a - service or downstream dependency. You can locate a change point by using a - sliding window, e.g. start with a small time range, like 30m, and make it - bigger until you identify a change point. It's very important to identify a - change point. If you don't have a change point, ask the user for next steps. - You can also use an anomaly or a deployment as a change point. Then, compare - data before the change with data after the change. You can either use the - groupBy parameter in get_apm_chart to get the most occuring values in a certain - data set, or you can use correlations to see for which field and value the - frequency has changed when comparing the foreground set to the background set. - This is useful when comparing data from before the change point with after the - change point. For instance, you might see a specific error pop up more often - after the change point. - - When comparing anomalies and changes in timeseries, first, zoom in to a smaller - time window, at least 30 minutes before and 30 minutes after the change - occured. E.g., if the anomaly occured at 2023-07-05T08:15:00.000Z, request a - time window that starts at 2023-07-05T07:45:00.000Z and ends at - 2023-07-05T08:45:00.000Z. When comparing changes in different timeseries and - anomalies to determine a correlation, make sure to compare the timestamps. If - in doubt, rate the likelihood of them being related, given the time difference, - between 1 and 10. If below 5, assume it's not related. Mention this likelihood - (and the time difference) to the user. - - Your goal is to help the user determine the root cause of an issue quickly and - transparently. If you see a change or - anomaly in a metric for a service, try to find similar changes in the metrics - for the traffic to its downstream dependencies, by comparing transaction - metrics to span metrics. 
To inspect the traffic from one service to a - downstream dependency, first get the downstream dependencies for a service, - then get the span metrics from that service (\`service.name\`) to its - downstream dependency (\`span.destination.service.resource\`). For instance, - for an anomaly in throughput, first inspect \`transaction_throughput\` for - \`service.name\`. Then, inspect \`exit_span_throughput\` for its downstream - dependencies, by grouping by \`span.destination.service.resource\`. Repeat this - process over the next service its downstream dependencies until you identify a - root cause. If you can not find any similar changes, use correlations or - grouping to find attributes that could be causes for the change.`, + registerRenderFunction, }); } diff --git a/x-pack/plugins/apm/public/plugin.ts b/x-pack/plugins/apm/public/plugin.ts index 80dfa77ca40d0f..88d0628b46f13d 100644 --- a/x-pack/plugins/apm/public/plugin.ts +++ b/x-pack/plugins/apm/public/plugin.ts @@ -428,15 +428,11 @@ export class ApmPlugin implements Plugin { const { fleet } = plugins; plugins.observabilityAIAssistant.register( - async ({ signal, registerContext, registerFunction }) => { + async ({ registerRenderFunction }) => { const mod = await import('./assistant_functions'); mod.registerAssistantFunctions({ - coreStart: core, - pluginsStart: plugins, - registerContext, - registerFunction, - signal, + registerRenderFunction, }); } ); diff --git a/x-pack/plugins/apm/public/assistant_functions/get_apm_correlations.ts b/x-pack/plugins/apm/server/assistant_functions/get_apm_correlations.ts similarity index 91% rename from x-pack/plugins/apm/public/assistant_functions/get_apm_correlations.ts rename to x-pack/plugins/apm/server/assistant_functions/get_apm_correlations.ts index 2d600e41e48604..f3f61f7d169376 100644 --- a/x-pack/plugins/apm/public/assistant_functions/get_apm_correlations.ts +++ b/x-pack/plugins/apm/server/assistant_functions/get_apm_correlations.ts @@ -6,15 +6,14 @@ */ import { 
i18n } from '@kbn/i18n'; -import type { RegisterFunctionDefinition } from '@kbn/observability-ai-assistant-plugin/common/types'; +import { FunctionRegistrationParameters } from '.'; import { CorrelationsEventType } from '../../common/assistant/constants'; -import { callApmApi } from '../services/rest/create_call_apm_api'; +import { getApmCorrelationValues } from '../routes/assistant_functions/get_apm_correlation_values'; export function registerGetApmCorrelationsFunction({ + apmEventClient, registerFunction, -}: { - registerFunction: RegisterFunctionDefinition; -}) { +}: FunctionRegistrationParameters) { registerFunction( { name: 'get_apm_correlations', @@ -113,12 +112,12 @@ export function registerGetApmCorrelationsFunction({ } as const, }, async ({ arguments: args }, signal) => { - return callApmApi('POST /internal/apm/assistant/get_correlation_values', { - signal, - params: { - body: args, - }, - }); + return { + content: await getApmCorrelationValues({ + arguments: args as any, + apmEventClient, + }), + }; } ); } diff --git a/x-pack/plugins/apm/public/assistant_functions/get_apm_downstream_dependencies.ts b/x-pack/plugins/apm/server/assistant_functions/get_apm_downstream_dependencies.ts similarity index 83% rename from x-pack/plugins/apm/public/assistant_functions/get_apm_downstream_dependencies.ts rename to x-pack/plugins/apm/server/assistant_functions/get_apm_downstream_dependencies.ts index fea837dd76c31e..45c1b876974aac 100644 --- a/x-pack/plugins/apm/public/assistant_functions/get_apm_downstream_dependencies.ts +++ b/x-pack/plugins/apm/server/assistant_functions/get_apm_downstream_dependencies.ts @@ -6,14 +6,13 @@ */ import { i18n } from '@kbn/i18n'; -import type { RegisterFunctionDefinition } from '@kbn/observability-ai-assistant-plugin/common/types'; -import { callApmApi } from '../services/rest/create_call_apm_api'; +import { FunctionRegistrationParameters } from '.'; +import { getAssistantDownstreamDependencies } from 
'../routes/assistant_functions/get_apm_downstream_dependencies'; export function registerGetApmDownstreamDependenciesFunction({ + apmEventClient, registerFunction, -}: { - registerFunction: RegisterFunctionDefinition; -}) { +}: FunctionRegistrationParameters) { registerFunction( { name: 'get_apm_downstream_dependencies', @@ -57,15 +56,12 @@ export function registerGetApmDownstreamDependenciesFunction({ } as const, }, async ({ arguments: args }, signal) => { - return callApmApi( - 'GET /internal/apm/assistant/get_downstream_dependencies', - { - signal, - params: { - query: args, - }, - } - ); + return { + content: await getAssistantDownstreamDependencies({ + arguments: args, + apmEventClient, + }), + }; } ); } diff --git a/x-pack/plugins/apm/public/assistant_functions/get_apm_error_document.ts b/x-pack/plugins/apm/server/assistant_functions/get_apm_error_document.ts similarity index 83% rename from x-pack/plugins/apm/public/assistant_functions/get_apm_error_document.ts rename to x-pack/plugins/apm/server/assistant_functions/get_apm_error_document.ts index a5c66478e3fd78..e5082f47ad8eb6 100644 --- a/x-pack/plugins/apm/public/assistant_functions/get_apm_error_document.ts +++ b/x-pack/plugins/apm/server/assistant_functions/get_apm_error_document.ts @@ -6,14 +6,13 @@ */ import { i18n } from '@kbn/i18n'; -import type { RegisterFunctionDefinition } from '@kbn/observability-ai-assistant-plugin/common/types'; -import { callApmApi } from '../services/rest/create_call_apm_api'; +import type { FunctionRegistrationParameters } from '.'; +import { getApmErrorDocument } from '../routes/assistant_functions/get_apm_error_document'; export function registerGetApmErrorDocumentFunction({ + apmEventClient, registerFunction, -}: { - registerFunction: RegisterFunctionDefinition; -}) { +}: FunctionRegistrationParameters) { registerFunction( { name: 'get_apm_error_document', @@ -55,12 +54,12 @@ export function registerGetApmErrorDocumentFunction({ } as const, }, async ({ arguments: args }, 
signal) => { - return callApmApi('GET /internal/apm/assistant/get_error_document', { - signal, - params: { - query: args, - }, - }); + return { + content: await getApmErrorDocument({ + apmEventClient, + arguments: args, + }), + }; } ); } diff --git a/x-pack/plugins/apm/public/assistant_functions/get_apm_service_summary.ts b/x-pack/plugins/apm/server/assistant_functions/get_apm_service_summary.ts similarity index 64% rename from x-pack/plugins/apm/public/assistant_functions/get_apm_service_summary.ts rename to x-pack/plugins/apm/server/assistant_functions/get_apm_service_summary.ts index 189633ec959750..291e4fdae33b07 100644 --- a/x-pack/plugins/apm/public/assistant_functions/get_apm_service_summary.ts +++ b/x-pack/plugins/apm/server/assistant_functions/get_apm_service_summary.ts @@ -5,16 +5,19 @@ * 2.0. */ +import type { ElasticsearchClient } from '@kbn/core-elasticsearch-server'; import { i18n } from '@kbn/i18n'; -import type { RegisterFunctionDefinition } from '@kbn/observability-ai-assistant-plugin/common/types'; -import { callApmApi } from '../services/rest/create_call_apm_api'; +import type { FunctionRegistrationParameters } from '.'; +import { getApmAlertsClient } from '../lib/helpers/get_apm_alerts_client'; +import { getMlClient } from '../lib/helpers/get_ml_client'; +import { getApmServiceSummary } from '../routes/assistant_functions/get_apm_service_summary'; import { NON_EMPTY_STRING } from '../utils/non_empty_string_ref'; export function registerGetApmServiceSummaryFunction({ + resources, + apmEventClient, registerFunction, -}: { - registerFunction: RegisterFunctionDefinition; -}) { +}: FunctionRegistrationParameters) { registerFunction( { name: 'get_apm_service_summary', @@ -58,12 +61,33 @@ alerts and anomalies.`, } as const, }, async ({ arguments: args }, signal) => { - return callApmApi('GET /internal/apm/assistant/get_service_summary', { - signal, - params: { - query: args, - }, - }); + const { context, request, plugins, logger } = resources; + + 
const [annotationsClient, esClient, apmAlertsClient, mlClient] = + await Promise.all([ + plugins.observability.setup.getScopedAnnotationsClient( + context, + request + ), + context.core.then( + (coreContext): ElasticsearchClient => + coreContext.elasticsearch.client.asCurrentUser + ), + getApmAlertsClient(resources), + getMlClient(resources), + ]); + + return { + content: await getApmServiceSummary({ + apmEventClient, + annotationsClient, + esClient, + apmAlertsClient, + mlClient, + logger, + arguments: args, + }), + }; } ); } diff --git a/x-pack/plugins/apm/server/assistant_functions/get_apm_services_list.ts b/x-pack/plugins/apm/server/assistant_functions/get_apm_services_list.ts new file mode 100644 index 00000000000000..178d47f8d33edf --- /dev/null +++ b/x-pack/plugins/apm/server/assistant_functions/get_apm_services_list.ts @@ -0,0 +1,137 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +import datemath from '@elastic/datemath'; +import { i18n } from '@kbn/i18n'; +import { FunctionRegistrationParameters } from '.'; +import { ApmDocumentType } from '../../common/document_type'; +import { ENVIRONMENT_ALL } from '../../common/environment_filter_values'; +import { RollupInterval } from '../../common/rollup'; +import { ServiceHealthStatus } from '../../common/service_health_status'; +import { getApmAlertsClient } from '../lib/helpers/get_apm_alerts_client'; +import { getMlClient } from '../lib/helpers/get_ml_client'; +import { getRandomSampler } from '../lib/helpers/get_random_sampler'; +import { getServicesItems } from '../routes/services/get_services/get_services_items'; +import { NON_EMPTY_STRING } from '../utils/non_empty_string_ref'; + +export interface ApmServicesListItem { + 'service.name': string; + 'agent.name'?: string; + 'transaction.type'?: string; + alertsCount: number; + healthStatus: ServiceHealthStatus; + 'service.environment'?: string[]; +} + +export function registerGetApmServicesListFunction({ + apmEventClient, + resources, + registerFunction, +}: FunctionRegistrationParameters) { + registerFunction( + { + name: 'get_apm_services_list', + contexts: ['apm'], + description: `Gets a list of services`, + descriptionForUser: i18n.translate( + 'xpack.apm.observabilityAiAssistant.functions.registerGetApmServicesList.descriptionForUser', + { + defaultMessage: `Gets the list of monitored services, their health status, and alerts.`, + } + ), + parameters: { + type: 'object', + additionalProperties: false, + properties: { + 'service.environment': { + ...NON_EMPTY_STRING, + description: + 'Optionally filter the services by the environments that they are running in', + }, + start: { + ...NON_EMPTY_STRING, + description: + 'The start of the time range, in Elasticsearch date math, like `now`.', + }, + end: { + ...NON_EMPTY_STRING, + description: + 'The end of the time range, in Elasticsearch date math, like `now-24h`.', + }, + healthStatus: 
{ + type: 'array', + description: 'Filter service list by health status', + additionalProperties: false, + additionalItems: false, + items: { + type: 'string', + enum: [ + ServiceHealthStatus.unknown, + ServiceHealthStatus.healthy, + ServiceHealthStatus.warning, + ServiceHealthStatus.critical, + ], + }, + }, + }, + required: ['start', 'end'], + } as const, + }, + async ({ arguments: args }, signal) => { + const { healthStatus } = args; + const [apmAlertsClient, mlClient, randomSampler] = await Promise.all([ + getApmAlertsClient(resources), + getMlClient(resources), + getRandomSampler({ + security: resources.plugins.security, + probability: 1, + request: resources.request, + }), + ]); + + const start = datemath.parse(args.start)?.valueOf()!; + const end = datemath.parse(args.end)?.valueOf()!; + + const serviceItems = await getServicesItems({ + apmAlertsClient, + apmEventClient, + documentType: ApmDocumentType.TransactionMetric, + start, + end, + environment: args['service.environment'] || ENVIRONMENT_ALL.value, + kuery: '', + logger: resources.logger, + randomSampler, + rollupInterval: RollupInterval.OneMinute, + serviceGroup: null, + mlClient, + useDurationSummary: false, + }); + + let mappedItems = serviceItems.items.map((item): ApmServicesListItem => { + return { + 'service.name': item.serviceName, + 'agent.name': item.agentName, + alertsCount: item.alertsCount ?? 0, + healthStatus: item.healthStatus ?? 
ServiceHealthStatus.unknown, + 'service.environment': item.environments, + 'transaction.type': item.transactionType, + }; + }); + + if (healthStatus && healthStatus.length) { + mappedItems = mappedItems.filter((item): boolean => + healthStatus.includes(item.healthStatus) + ); + } + + return { + content: mappedItems, + }; + } + ); +} diff --git a/x-pack/plugins/apm/server/assistant_functions/get_apm_timeseries.ts b/x-pack/plugins/apm/server/assistant_functions/get_apm_timeseries.ts new file mode 100644 index 00000000000000..0d5be32611eb43 --- /dev/null +++ b/x-pack/plugins/apm/server/assistant_functions/get_apm_timeseries.ts @@ -0,0 +1,174 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ +import { i18n } from '@kbn/i18n'; +import { FromSchema } from 'json-schema-to-ts'; +import { omit } from 'lodash'; +import { FunctionRegistrationParameters } from '.'; +import { + ApmTimeseries, + getApmTimeseries, +} from '../routes/assistant_functions/get_apm_timeseries'; +import { NON_EMPTY_STRING } from '../utils/non_empty_string_ref'; + +const parameters = { + type: 'object', + properties: { + start: { + type: 'string', + description: + 'The start of the time range, in Elasticsearch date math, like `now`.', + }, + end: { + type: 'string', + description: + 'The end of the time range, in Elasticsearch date math, like `now-24h`.', + }, + stats: { + type: 'array', + items: { + type: 'object', + properties: { + timeseries: { + description: 'The metric to be displayed', + oneOf: [ + { + type: 'object', + properties: { + name: { + type: 'string', + enum: [ + 'transaction_throughput', + 'transaction_failure_rate', + ], + }, + 'transaction.type': { + type: 'string', + description: 'The transaction type', + }, + }, + required: ['name'], + }, + { + type: 'object', + properties: 
{ + name: { + type: 'string', + enum: [ + 'exit_span_throughput', + 'exit_span_failure_rate', + 'exit_span_latency', + ], + }, + 'span.destination.service.resource': { + type: 'string', + description: + 'The name of the downstream dependency for the service', + }, + }, + required: ['name'], + }, + { + type: 'object', + properties: { + name: { + type: 'string', + const: 'error_event_rate', + }, + }, + required: ['name'], + }, + { + type: 'object', + properties: { + name: { + type: 'string', + const: 'transaction_latency', + }, + 'transaction.type': { + type: 'string', + }, + function: { + type: 'string', + enum: ['avg', 'p95', 'p99'], + }, + }, + required: ['name', 'function'], + }, + ], + }, + 'service.name': { + ...NON_EMPTY_STRING, + description: 'The name of the service', + }, + 'service.environment': { + description: + 'The environment that the service is running in. If undefined, all environments will be included. Only use this if you have confirmed the environment that the service is running in.', + }, + filter: { + type: 'string', + description: + 'a KQL query to filter the data by. If no filter should be applied, leave it empty.', + }, + title: { + type: 'string', + description: + 'A unique, human readable, concise title for this specific group series.', + }, + offset: { + type: 'string', + description: + 'The offset. Right: 15m. 8h. 1d. Wrong: -15m. -8h. 
-1d.', + }, + }, + required: ['service.name', 'timeseries', 'title'], + }, + }, + }, + required: ['stats', 'start', 'end'], +} as const; + +export function registerGetApmTimeseriesFunction({ + apmEventClient, + registerFunction, +}: FunctionRegistrationParameters) { + registerFunction( + { + contexts: ['apm'], + name: 'get_apm_timeseries', + descriptionForUser: i18n.translate( + 'xpack.apm.observabilityAiAssistant.functions.registerGetApmTimeseries.descriptionForUser', + { + defaultMessage: `Display different APM metrics, like throughput, failure rate, or latency, for any service or all services, or any or all of its dependencies, both as a timeseries and as a single statistic. Additionally, the function will return any changes, such as spikes, step and trend changes, or dips. You can also use it to compare data by requesting two different time ranges, or for instance two different service versions`, + } + ), + description: `Visualise and analyse different APM metrics, like throughput, failure rate, or latency, for any service or all services, or any or all of its dependencies, both as a timeseries and as a single statistic. A visualisation will be displayed above your reply - DO NOT attempt to display or generate an image yourself, or any other placeholder. Additionally, the function will return any changes, such as spikes, step and trend changes, or dips. 
You can also use it to compare data by requesting two different time ranges, or for instance two different service versions.`, + parameters, + }, + async ( + { arguments: args }, + signal + ): Promise => { + const timeseries = await getApmTimeseries({ + apmEventClient, + arguments: args as any, + }); + + return { + content: timeseries.map( + (series): Omit => omit(series, 'data') + ), + data: timeseries, + }; + } + ); +} + +export type GetApmTimeseriesFunctionArguments = FromSchema; +export interface GetApmTimeseriesFunctionResponse { + content: Array>; + data: ApmTimeseries[]; +} diff --git a/x-pack/plugins/apm/server/assistant_functions/index.ts b/x-pack/plugins/apm/server/assistant_functions/index.ts new file mode 100644 index 00000000000000..81521a233ad1d6 --- /dev/null +++ b/x-pack/plugins/apm/server/assistant_functions/index.ts @@ -0,0 +1,186 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +import type { CoreSetup } from '@kbn/core-lifecycle-server'; +import type { Logger } from '@kbn/logging'; +import type { + ChatRegistrationFunction, + RegisterFunction, +} from '@kbn/observability-ai-assistant-plugin/server/service/types'; +import type { IRuleDataClient } from '@kbn/rule-registry-plugin/server'; +import type { APMConfig } from '..'; +import type { ApmFeatureFlags } from '../../common/apm_feature_flags'; +import { APMEventClient } from '../lib/helpers/create_es_client/create_apm_event_client'; +import { getApmEventClient } from '../lib/helpers/get_apm_event_client'; +import type { APMRouteHandlerResources } from '../routes/apm_routes/register_apm_server_routes'; +import { hasHistoricalAgentData } from '../routes/historical_data/has_historical_agent_data'; +import { registerGetApmCorrelationsFunction } from './get_apm_correlations'; +import { registerGetApmDownstreamDependenciesFunction } from './get_apm_downstream_dependencies'; +import { registerGetApmErrorDocumentFunction } from './get_apm_error_document'; +import { registerGetApmServicesListFunction } from './get_apm_services_list'; +import { registerGetApmServiceSummaryFunction } from './get_apm_service_summary'; +import { registerGetApmTimeseriesFunction } from './get_apm_timeseries'; + +export interface FunctionRegistrationParameters { + apmEventClient: APMEventClient; + registerFunction: RegisterFunction; + resources: APMRouteHandlerResources; +} + +export function registerAssistantFunctions({ + coreSetup, + config, + featureFlags, + logger, + kibanaVersion, + ruleDataClient, + plugins, +}: { + coreSetup: CoreSetup; + config: APMConfig; + featureFlags: ApmFeatureFlags; + logger: Logger; + kibanaVersion: string; + ruleDataClient: IRuleDataClient; + plugins: APMRouteHandlerResources['plugins']; +}): ChatRegistrationFunction { + return async ({ resources, registerContext, registerFunction }) => { + const apmRouteHandlerResources: APMRouteHandlerResources = { + context: resources.context, 
+ request: resources.request, + core: { + setup: coreSetup, + start: () => + coreSetup.getStartServices().then(([coreStart]) => coreStart), + }, + params: { + query: { + _inspect: false, + }, + }, + config, + featureFlags, + logger, + kibanaVersion, + ruleDataClient, + plugins, + getApmIndices: async () => { + const coreContext = await resources.context.core; + const apmIndices = await plugins.apmDataAccess.setup.getApmIndices( + coreContext.savedObjects.client + ); + return apmIndices; + }, + }; + + const apmEventClient = await getApmEventClient(apmRouteHandlerResources); + + const hasData = await hasHistoricalAgentData(apmEventClient); + + if (!hasData) { + return; + } + + const parameters: FunctionRegistrationParameters = { + resources: apmRouteHandlerResources, + apmEventClient, + registerFunction, + }; + + registerGetApmServicesListFunction(parameters); + registerGetApmServiceSummaryFunction(parameters); + registerGetApmErrorDocumentFunction(parameters); + registerGetApmDownstreamDependenciesFunction(parameters); + registerGetApmCorrelationsFunction(parameters); + registerGetApmTimeseriesFunction(parameters); + + registerContext({ + name: 'apm', + description: ` + When analyzing APM data, prefer the APM specific functions over the generic Lens, + Elasticsearch or Kibana ones, unless those are explicitly requested by the user. + + When requesting metrics for a service, make sure you also know what environment + it is running in. Metrics aggregated over multiple environments are useless. + + There are four important data types in Elastic APM. Each of them have the + following fields: + - service.name: the name of the service + - service.node.name: the id of the service instance (often the hostname) + - service.environment: the environment (often production, development) + - agent.name: the name of the agent (go, java, etc) + + The four data types are transactions, exit spans, error events, and application + metrics. 
 + + Transactions have three metrics: throughput, failure rate, and latency. The + fields are: + + - transaction.type: often request or page-load (the main transaction types), + but can also be worker, or route-change. + - transaction.name: The name of the transaction group, often something like + 'GET /api/product/:productId' + - transaction.result: The result. Used to capture HTTP response codes + (2xx,3xx,4xx,5xx) for request transactions. + - event.outcome: whether the transaction was successful or not. success, + failure, or unknown. + + Exit spans have three metrics: throughput, failure rate and latency. The fields + are: + - span.type: db, external + - span.subtype: the type of database (redis, postgres) or protocol (http, grpc) + - span.destination.service.resource: the address of the destination of the call + - event.outcome: whether the transaction was successful or not. success, + failure, or unknown. + + Error events have one metric, error event rate. The fields are: + - error.grouping_name: a human readable keyword that identifies the error group + + For transaction metrics we also collect anomalies. These are scored 0 (low) to + 100 (critical). + + For root cause analysis, locate a change point in the relevant metrics for a + service or downstream dependency. You can locate a change point by using a + sliding window, e.g. start with a small time range, like 30m, and make it + bigger until you identify a change point. It's very important to identify a + change point. If you don't have a change point, ask the user for next steps. + You can also use an anomaly or a deployment as a change point. Then, compare + data before the change with data after the change. You can either use the + groupBy parameter in get_apm_chart to get the most occurring values in a certain + data set, or you can use correlations to see for which field and value the + frequency has changed when comparing the foreground set to the background set. 
 + This is useful when comparing data from before the change point with after the + change point. For instance, you might see a specific error pop up more often + after the change point. + + When comparing anomalies and changes in timeseries, first, zoom in to a smaller + time window, at least 30 minutes before and 30 minutes after the change + occurred. E.g., if the anomaly occurred at 2023-07-05T08:15:00.000Z, request a + time window that starts at 2023-07-05T07:45:00.000Z and ends at + 2023-07-05T08:45:00.000Z. When comparing changes in different timeseries and + anomalies to determine a correlation, make sure to compare the timestamps. If + in doubt, rate the likelihood of them being related, given the time difference, + between 1 and 10. If below 5, assume it's not related. Mention this likelihood + (and the time difference) to the user. + + Your goal is to help the user determine the root cause of an issue quickly and + transparently. If you see a change or + anomaly in a metric for a service, try to find similar changes in the metrics + for the traffic to its downstream dependencies, by comparing transaction + metrics to span metrics. To inspect the traffic from one service to a + downstream dependency, first get the downstream dependencies for a service, + then get the span metrics from that service (\`service.name\`) to its + downstream dependency (\`span.destination.service.resource\`). For instance, + for an anomaly in throughput, first inspect \`transaction_throughput\` for + \`service.name\`. Then, inspect \`exit_span_throughput\` for its downstream + dependencies, by grouping by \`span.destination.service.resource\`. Repeat this + process over the next service and its downstream dependencies until you identify a + root cause. 
If you can not find any similar changes, use correlations or + grouping to find attributes that could be causes for the change.`, + }); + }; +} diff --git a/x-pack/plugins/apm/server/plugin.ts b/x-pack/plugins/apm/server/plugin.ts index 07010d5f8dc5b4..9e1454a3e54943 100644 --- a/x-pack/plugins/apm/server/plugin.ts +++ b/x-pack/plugins/apm/server/plugin.ts @@ -49,6 +49,7 @@ import { scheduleSourceMapMigration } from './routes/source_maps/schedule_source import { createApmSourceMapIndexTemplate } from './routes/source_maps/create_apm_source_map_index_template'; import { addApiKeysToEveryPackagePolicyIfMissing } from './routes/fleet/api_keys/add_api_keys_to_policies_if_missing'; import { apmTutorialCustomIntegration } from '../common/tutorial/tutorials'; +import { registerAssistantFunctions } from './assistant_functions'; export class APMPlugin implements @@ -167,6 +168,8 @@ export class APMPlugin APM_SERVER_FEATURE_ID ); + const kibanaVersion = this.initContext.env.packageInfo.version; + registerRoutes({ core: { setup: core, @@ -179,7 +182,7 @@ export class APMPlugin ruleDataClient, plugins: resourcePlugins, telemetryUsageCounter, - kibanaVersion: this.initContext.env.packageInfo.version, + kibanaVersion, }); const { getApmIndices } = plugins.apmDataAccess; @@ -230,6 +233,18 @@ export class APMPlugin this.logger?.error(e); }); + plugins.observabilityAIAssistant.service.registration( + registerAssistantFunctions({ + config: this.currentConfig!, + coreSetup: core, + featureFlags: this.currentConfig!.featureFlags, + kibanaVersion, + logger: this.logger.get('assistant'), + plugins: resourcePlugins, + ruleDataClient, + }) + ); + return { config$ }; } diff --git a/x-pack/plugins/apm/server/routes/assistant_functions/route.ts b/x-pack/plugins/apm/server/routes/assistant_functions/route.ts index df7bb7e7a146dd..5aed2451858a95 100644 --- a/x-pack/plugins/apm/server/routes/assistant_functions/route.ts +++ b/x-pack/plugins/apm/server/routes/assistant_functions/route.ts @@ 
-4,21 +4,13 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ -import datemath from '@elastic/datemath'; import { ElasticsearchClient } from '@kbn/core/server'; import * as t from 'io-ts'; import { omit } from 'lodash'; -import { ApmDocumentType } from '../../../common/document_type'; -import { ENVIRONMENT_ALL } from '../../../common/environment_filter_values'; -import { RollupInterval } from '../../../common/rollup'; -import { ServiceHealthStatus } from '../../../common/service_health_status'; -import type { APMError } from '../../../typings/es_schemas/ui/apm_error'; import { getApmAlertsClient } from '../../lib/helpers/get_apm_alerts_client'; import { getApmEventClient } from '../../lib/helpers/get_apm_event_client'; import { getMlClient } from '../../lib/helpers/get_ml_client'; -import { getRandomSampler } from '../../lib/helpers/get_random_sampler'; import { createApmServerRoute } from '../apm_routes/create_apm_server_route'; -import { getServicesItems } from '../services/get_services/get_services_items'; import { CorrelationValue, correlationValuesRouteRt, @@ -29,7 +21,6 @@ import { getAssistantDownstreamDependencies, type APMDownstreamDependency, } from './get_apm_downstream_dependencies'; -import { errorRouteRt, getApmErrorDocument } from './get_apm_error_document'; import { getApmServiceSummary, serviceSummaryRouteRt, @@ -167,130 +158,9 @@ const getApmCorrelationValuesRoute = createApmServerRoute({ }, }); -const getApmErrorDocRoute = createApmServerRoute({ - endpoint: 'GET /internal/apm/assistant/get_error_document', - params: t.type({ - query: errorRouteRt, - }), - options: { - tags: ['access:apm'], - }, - handler: async ( - resources - ): Promise<{ content: Array> }> => { - const { params } = resources; - const apmEventClient = await getApmEventClient(resources); - const { query } = params; - - return { - content: await getApmErrorDocument({ - apmEventClient, - arguments: query, - }), - }; - }, -}); - -export 
interface ApmServicesListItem { - 'service.name': string; - 'agent.name'?: string; - 'transaction.type'?: string; - alertsCount: number; - healthStatus: ServiceHealthStatus; - 'service.environment'?: string[]; -} - -type ApmServicesListContent = ApmServicesListItem[]; - -const getApmServicesListRoute = createApmServerRoute({ - endpoint: 'POST /internal/apm/assistant/get_services_list', - params: t.type({ - body: t.intersection([ - t.type({ - start: t.string, - end: t.string, - }), - t.partial({ - 'service.environment': t.string, - healthStatus: t.array( - t.union([ - t.literal(ServiceHealthStatus.unknown), - t.literal(ServiceHealthStatus.healthy), - t.literal(ServiceHealthStatus.warning), - t.literal(ServiceHealthStatus.critical), - ]) - ), - }), - ]), - }), - options: { - tags: ['access:apm'], - }, - handler: async (resources): Promise<{ content: ApmServicesListContent }> => { - const { params } = resources; - const { body } = params; - - const { healthStatus } = body; - - const [apmEventClient, apmAlertsClient, mlClient, randomSampler] = - await Promise.all([ - getApmEventClient(resources), - getApmAlertsClient(resources), - getMlClient(resources), - getRandomSampler({ - security: resources.plugins.security, - probability: 1, - request: resources.request, - }), - ]); - - const start = datemath.parse(body.start)?.valueOf()!; - const end = datemath.parse(body.end)?.valueOf()!; - - const serviceItems = await getServicesItems({ - apmAlertsClient, - apmEventClient, - documentType: ApmDocumentType.TransactionMetric, - start, - end, - environment: body['service.environment'] || ENVIRONMENT_ALL.value, - kuery: '', - logger: resources.logger, - randomSampler, - rollupInterval: RollupInterval.OneMinute, - serviceGroup: null, - mlClient, - useDurationSummary: false, - }); - - let mappedItems = serviceItems.items.map((item): ApmServicesListItem => { - return { - 'service.name': item.serviceName, - 'agent.name': item.agentName, - alertsCount: item.alertsCount ?? 
0, - healthStatus: item.healthStatus ?? ServiceHealthStatus.unknown, - 'service.environment': item.environments, - 'transaction.type': item.transactionType, - }; - }); - - if (healthStatus && healthStatus.length) { - mappedItems = mappedItems.filter((item): boolean => - healthStatus.includes(item.healthStatus) - ); - } - - return { - content: mappedItems, - }; - }, -}); - export const assistantRouteRepository = { ...getApmTimeSeriesRoute, ...getApmServiceSummaryRoute, - ...getApmErrorDocRoute, ...getApmCorrelationValuesRoute, ...getDownstreamDependenciesRoute, - ...getApmServicesListRoute, }; diff --git a/x-pack/plugins/apm/server/types.ts b/x-pack/plugins/apm/server/types.ts index 10687e37578497..fa77e52e768367 100644 --- a/x-pack/plugins/apm/server/types.ts +++ b/x-pack/plugins/apm/server/types.ts @@ -65,6 +65,10 @@ import { ProfilingDataAccessPluginSetup, ProfilingDataAccessPluginStart, } from '@kbn/profiling-data-access-plugin/server'; +import type { + ObservabilityAIAssistantPluginSetup, + ObservabilityAIAssistantPluginStart, +} from '@kbn/observability-ai-assistant-plugin/server'; import { APMConfig } from '.'; export interface APMPluginSetup { @@ -82,7 +86,7 @@ export interface APMPluginSetupDependencies { metricsDataAccess: MetricsDataPluginSetup; dataViews: {}; share: SharePluginSetup; - + observabilityAIAssistant: ObservabilityAIAssistantPluginSetup; // optional dependencies actions?: ActionsPlugin['setup']; alerting?: AlertingPlugin['setup']; @@ -108,7 +112,7 @@ export interface APMPluginStartDependencies { metricsDataAccess: MetricsDataPluginSetup; dataViews: DataViewsServerPluginStart; share: undefined; - + observabilityAIAssistant: ObservabilityAIAssistantPluginStart; // optional dependencies actions?: ActionsPlugin['start']; alerting?: AlertingPlugin['start']; diff --git a/x-pack/plugins/apm/public/utils/non_empty_string_ref.ts b/x-pack/plugins/apm/server/utils/non_empty_string_ref.ts similarity index 100% rename from 
x-pack/plugins/apm/public/utils/non_empty_string_ref.ts rename to x-pack/plugins/apm/server/utils/non_empty_string_ref.ts diff --git a/x-pack/plugins/observability_ai_assistant/common/conversation_complete.ts b/x-pack/plugins/observability_ai_assistant/common/conversation_complete.ts new file mode 100644 index 00000000000000..4002eec9d10d78 --- /dev/null +++ b/x-pack/plugins/observability_ai_assistant/common/conversation_complete.ts @@ -0,0 +1,109 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ +/* eslint-disable max-classes-per-file*/ +import { i18n } from '@kbn/i18n'; +import { Message } from './types'; + +export enum StreamingChatResponseEventType { + ChatCompletionChunk = 'chatCompletionChunk', + ConversationCreate = 'conversationCreate', + ConversationUpdate = 'conversationUpdate', + MessageAdd = 'messageAdd', + ConversationCompletionError = 'conversationCompletionError', +} + +type StreamingChatResponseEventBase< + TEventType extends StreamingChatResponseEventType, + TData extends {} +> = { + type: TEventType; +} & TData; + +type ChatCompletionChunkEvent = StreamingChatResponseEventBase< + StreamingChatResponseEventType.ChatCompletionChunk, + { + id: string; + message: { + content?: string; + function_call?: { + name?: string; + arguments?: string; + }; + }; + } +>; + +export type ConversationCreateEvent = StreamingChatResponseEventBase< + StreamingChatResponseEventType.ConversationCreate, + { + conversation: { + id: string; + title: string; + last_updated: string; + }; + } +>; + +export type ConversationUpdateEvent = StreamingChatResponseEventBase< + StreamingChatResponseEventType.ConversationUpdate, + { + conversation: { + id: string; + title: string; + last_updated: string; + }; + } +>; + +type MessageAddEvent = StreamingChatResponseEventBase< + 
StreamingChatResponseEventType.MessageAdd, + { message: Message; id: string } +>; + +type ConversationCompletionErrorEvent = StreamingChatResponseEventBase< + StreamingChatResponseEventType.ConversationCompletionError, + { error: { message: string; stack?: string; code?: ChatCompletionErrorCode } } +>; + +export type StreamingChatResponseEvent = + | ChatCompletionChunkEvent + | ConversationCreateEvent + | ConversationUpdateEvent + | MessageAddEvent + | ConversationCompletionErrorEvent; + +export enum ChatCompletionErrorCode { + InternalError = 'internalError', + NotFound = 'notFound', +} + +export class ConversationCompletionError extends Error { + code: ChatCompletionErrorCode; + + constructor(code: ChatCompletionErrorCode, message: string) { + super(message); + this.code = code; + } +} + +export class ConversationNotFoundError extends ConversationCompletionError { + constructor() { + super( + ChatCompletionErrorCode.NotFound, + i18n.translate( + 'xpack.observabilityAiAssistant.conversationCompletionError.conversationNotFound', + { + defaultMessage: 'Conversation not found', + } + ) + ); + } +} + +export function isChatCompletionError(error: Error): error is ConversationCompletionError { + return error instanceof ConversationCompletionError; +} diff --git a/x-pack/plugins/observability_ai_assistant/common/functions/lens.tsx b/x-pack/plugins/observability_ai_assistant/common/functions/lens.tsx new file mode 100644 index 00000000000000..a3d8487a83b8ae --- /dev/null +++ b/x-pack/plugins/observability_ai_assistant/common/functions/lens.tsx @@ -0,0 +1,122 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +import { FromSchema } from 'json-schema-to-ts'; +import { FIELD_FORMAT_IDS } from '@kbn/field-formats-plugin/common'; + +export enum SeriesType { + Bar = 'bar', + Line = 'line', + Area = 'area', + BarStacked = 'bar_stacked', + AreaStacked = 'area_stacked', + BarHorizontal = 'bar_horizontal', + BarPercentageStacked = 'bar_percentage_stacked', + AreaPercentageStacked = 'area_percentage_stacked', + BarHorizontalPercentageStacked = 'bar_horizontal_percentage_stacked', +} + +export const lensFunctionDefinition = { + name: 'lens', + contexts: ['core'], + description: + "Use this function to create custom visualizations, using Lens, that can be saved to dashboards. This function does not return data to the assistant, it only shows it to the user. When using this function, make sure to use the recall function to get more information about how to use it, with how you want to use it. Make sure the query also contains information about the user's request. The visualisation is displayed to the user above your reply, DO NOT try to generate or display an image yourself.", + descriptionForUser: + 'Use this function to create custom visualizations, using Lens, that can be saved to dashboards.', + parameters: { + type: 'object', + additionalProperties: false, + properties: { + layers: { + type: 'array', + items: { + type: 'object', + additionalProperties: false, + properties: { + label: { + type: 'string', + }, + formula: { + type: 'string', + description: + 'The formula for calculating the value, e.g. sum(my_field_name). Query the knowledge base to get more information about the syntax and available formulas.', + }, + filter: { + type: 'string', + description: 'A KQL query that will be used as a filter for the series', + }, + format: { + type: 'object', + additionalProperties: false, + properties: { + id: { + type: 'string', + description: + 'How to format the value. When using duration, make sure the value is seconds OR is converted to seconds using math functions. 
Ask the user for clarification in which unit the value is stored, or derive it from the field name.', + enum: [ + FIELD_FORMAT_IDS.BYTES, + FIELD_FORMAT_IDS.CURRENCY, + FIELD_FORMAT_IDS.DURATION, + FIELD_FORMAT_IDS.NUMBER, + FIELD_FORMAT_IDS.PERCENT, + FIELD_FORMAT_IDS.STRING, + ], + }, + }, + required: ['id'], + }, + }, + required: ['label', 'formula', 'format'], + }, + }, + timeField: { + type: 'string', + default: '@timefield', + description: + 'time field to use for XY chart. Use @timefield if its available on the index.', + }, + breakdown: { + type: 'object', + additionalProperties: false, + properties: { + field: { + type: 'string', + }, + }, + required: ['field'], + }, + indexPattern: { + type: 'string', + }, + seriesType: { + type: 'string', + enum: [ + SeriesType.Area, + SeriesType.AreaPercentageStacked, + SeriesType.AreaStacked, + SeriesType.Bar, + SeriesType.BarHorizontal, + SeriesType.BarHorizontalPercentageStacked, + SeriesType.BarPercentageStacked, + SeriesType.BarStacked, + SeriesType.Line, + ], + }, + start: { + type: 'string', + description: 'The start of the time range, in Elasticsearch datemath', + }, + end: { + type: 'string', + description: 'The end of the time range, in Elasticsearch datemath', + }, + }, + required: ['layers', 'indexPattern', 'start', 'end', 'timeField'], + } as const, +}; + +export type LensFunctionArguments = FromSchema; diff --git a/x-pack/plugins/observability_ai_assistant/common/types.ts b/x-pack/plugins/observability_ai_assistant/common/types.ts index 690b931271bb6e..b4a48622db3826 100644 --- a/x-pack/plugins/observability_ai_assistant/common/types.ts +++ b/x-pack/plugins/observability_ai_assistant/common/types.ts @@ -5,10 +5,20 @@ * 2.0. 
*/ -import type { FromSchema } from 'json-schema-to-ts'; import type { JSONSchema } from 'json-schema-to-ts'; -import React from 'react'; -import { Observable } from 'rxjs'; +import type { + CreateChatCompletionResponse, + CreateChatCompletionResponseChoicesInner, +} from 'openai'; +import type { Observable } from 'rxjs'; + +export type CreateChatCompletionResponseChunk = Omit & { + choices: Array< + Omit & { + delta: { content?: string; function_call?: { name?: string; arguments?: string } }; + } + >; +}; export enum MessageRole { System = 'system', @@ -81,12 +91,12 @@ export interface ContextDefinition { description: string; } -type FunctionResponse = +export type FunctionResponse = | { content?: any; data?: any; } - | Observable; + | Observable; export enum FunctionVisibility { System = 'system', @@ -94,7 +104,9 @@ export enum FunctionVisibility { All = 'all', } -interface FunctionOptions { +export interface FunctionDefinition< + TParameters extends CompatibleJSONSchema = CompatibleJSONSchema +> { name: string; description: string; visibility?: FunctionVisibility; @@ -103,36 +115,7 @@ interface FunctionOptions = ( - options: { arguments: TArguments; messages: Message[]; connectorId: string }, - signal: AbortSignal -) => Promise; - -type RenderFunction = (options: { - arguments: TArguments; - response: TResponse; -}) => React.ReactNode; - -export interface FunctionDefinition { - options: FunctionOptions; - respond: ( - options: { arguments: any; messages: Message[]; connectorId: string }, - signal: AbortSignal - ) => Promise; - render?: RenderFunction; -} - export type RegisterContextDefinition = (options: ContextDefinition) => void; -export type RegisterFunctionDefinition = < - TParameters extends CompatibleJSONSchema, - TResponse extends FunctionResponse, - TArguments = FromSchema ->( - options: FunctionOptions, - respond: RespondFunction, - render?: RenderFunction -) => void; - export type ContextRegistry = Map; export type FunctionRegistry = Map; diff --git 
a/x-pack/plugins/observability_ai_assistant/common/utils/concatenate_openai_chunks.ts b/x-pack/plugins/observability_ai_assistant/common/utils/concatenate_openai_chunks.ts new file mode 100644 index 00000000000000..f15a193908a4e7 --- /dev/null +++ b/x-pack/plugins/observability_ai_assistant/common/utils/concatenate_openai_chunks.ts @@ -0,0 +1,34 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { cloneDeep } from 'lodash'; +import { type Observable, scan } from 'rxjs'; +import { CreateChatCompletionResponseChunk, MessageRole } from '../types'; + +export const concatenateOpenAiChunks = + () => (source: Observable) => + source.pipe( + scan( + (acc, { choices }) => { + acc.message.content += choices[0].delta.content ?? ''; + acc.message.function_call.name += choices[0].delta.function_call?.name ?? ''; + acc.message.function_call.arguments += choices[0].delta.function_call?.arguments ?? ''; + return cloneDeep(acc); + }, + { + message: { + content: '', + function_call: { + name: '', + arguments: '', + trigger: MessageRole.Assistant as const, + }, + role: MessageRole.Assistant, + }, + } + ) + ); diff --git a/x-pack/plugins/observability_ai_assistant/common/utils/filter_function_definitions.ts b/x-pack/plugins/observability_ai_assistant/common/utils/filter_function_definitions.ts new file mode 100644 index 00000000000000..3de6c3bce2484c --- /dev/null +++ b/x-pack/plugins/observability_ai_assistant/common/utils/filter_function_definitions.ts @@ -0,0 +1,29 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +import { FunctionDefinition } from '../types'; + +export function filterFunctionDefinitions({ + contexts, + filter, + definitions, +}: { + contexts?: string[]; + filter?: string; + definitions: FunctionDefinition[]; +}) { + return contexts || filter + ? definitions.filter((fn) => { + const matchesContext = + !contexts || fn.contexts.some((context) => contexts.includes(context)); + const matchesFilter = + !filter || fn.name.includes(filter) || fn.description.includes(filter); + + return matchesContext && matchesFilter; + }) + : definitions; +} diff --git a/x-pack/plugins/observability_ai_assistant/common/utils/process_openai_stream.ts b/x-pack/plugins/observability_ai_assistant/common/utils/process_openai_stream.ts new file mode 100644 index 00000000000000..e0d66bbf851323 --- /dev/null +++ b/x-pack/plugins/observability_ai_assistant/common/utils/process_openai_stream.ts @@ -0,0 +1,45 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ +/* eslint-disable max-classes-per-file*/ +import { filter, map, Observable, tap } from 'rxjs'; +import type { CreateChatCompletionResponseChunk } from '../types'; + +class TokenLimitReachedError extends Error { + constructor() { + super(`Token limit reached`); + } +} + +class ServerError extends Error {} + +export function processOpenAiStream() { + return (source: Observable): Observable => + source.pipe( + map((line) => line.substring(6)), + filter((line) => !!line && line !== '[DONE]'), + map( + (line) => + JSON.parse(line) as CreateChatCompletionResponseChunk | { error: { message: string } } + ), + tap((line) => { + if ('error' in line) { + throw new ServerError(line.error.message); + } + if ( + 'choices' in line && + line.choices.length && + line.choices[0].finish_reason === 'length' + ) { + throw new TokenLimitReachedError(); + } + }), + filter( + (line): line is CreateChatCompletionResponseChunk => + 'object' in line && line.object === 'chat.completion.chunk' + ) + ); +} diff --git a/x-pack/plugins/observability_ai_assistant/jest.config.js b/x-pack/plugins/observability_ai_assistant/jest.config.js index 1d6798f6c7623c..ff54dbc08c2b0b 100644 --- a/x-pack/plugins/observability_ai_assistant/jest.config.js +++ b/x-pack/plugins/observability_ai_assistant/jest.config.js @@ -11,5 +11,8 @@ module.exports = { roots: ['/x-pack/plugins/observability_ai_assistant'], setupFiles: ['/x-pack/plugins/observability_ai_assistant/.storybook/jest_setup.js'], collectCoverage: true, + collectCoverageFrom: [ + '/x-pack/plugins/observability_ai_assistant/{common,public,server}/**/*.{js,ts,tsx}', + ], coverageReporters: ['html'], }; diff --git a/x-pack/plugins/observability_ai_assistant/kibana.jsonc b/x-pack/plugins/observability_ai_assistant/kibana.jsonc index 291c7e658de187..cd2d4b788bc78b 100644 --- a/x-pack/plugins/observability_ai_assistant/kibana.jsonc +++ b/x-pack/plugins/observability_ai_assistant/kibana.jsonc @@ -8,6 +8,7 @@ "browser": true, "configPath": ["xpack", 
"observabilityAIAssistant"], "requiredPlugins": [ + "alerting", "actions", "dataViews", "features", @@ -21,7 +22,7 @@ "triggersActionsUi", "dataViews" ], - "requiredBundles": ["fieldFormats", "kibanaReact", "kibanaUtils"], + "requiredBundles": [ "kibanaReact", "kibanaUtils"], "optionalPlugins": [], "extraPublicDirs": [] } diff --git a/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_body.tsx b/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_body.tsx index afbb042d2eadcc..03587a6443a49b 100644 --- a/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_body.tsx +++ b/x-pack/plugins/observability_ai_assistant/public/components/chat/chat_body.tsx @@ -74,7 +74,7 @@ export function ChatBody({ modelsManagementHref: string; currentUser?: Pick; startedFrom?: StartedFrom; - onConversationUpdate: (conversation: Conversation) => void; + onConversationUpdate: (conversation: { conversation: Conversation['conversation'] }) => void; }) { const license = useLicense(); const hasCorrectLicense = license?.hasAtLeast('enterprise'); diff --git a/x-pack/plugins/observability_ai_assistant/public/components/chat/function_list_popover.tsx b/x-pack/plugins/observability_ai_assistant/public/components/chat/function_list_popover.tsx index a82ece36739d99..d38d4f1e4b45f9 100644 --- a/x-pack/plugins/observability_ai_assistant/public/components/chat/function_list_popover.tsx +++ b/x-pack/plugins/observability_ai_assistant/public/components/chat/function_list_popover.tsx @@ -172,12 +172,12 @@ function mapFunctions({ selectedFunctionName: string | undefined; }) { return functions - .filter((func) => func.options.visibility !== FunctionVisibility.System) + .filter((func) => func.visibility !== FunctionVisibility.System) .map((func) => ({ - label: func.options.name, - searchableLabel: func.options.descriptionForUser || func.options.description, + label: func.name, + searchableLabel: func.descriptionForUser || func.description, checked: - 
func.options.name === selectedFunctionName + func.name === selectedFunctionName ? ('on' as EuiSelectableOptionCheckedType) : ('off' as EuiSelectableOptionCheckedType), })); diff --git a/x-pack/plugins/observability_ai_assistant/public/components/insight/insight.tsx b/x-pack/plugins/observability_ai_assistant/public/components/insight/insight.tsx index 48ba86a98fbe8f..a62ec1d5fa29ff 100644 --- a/x-pack/plugins/observability_ai_assistant/public/components/insight/insight.tsx +++ b/x-pack/plugins/observability_ai_assistant/public/components/insight/insight.tsx @@ -43,6 +43,7 @@ function ChatContent({ chatService, connectorId, initialMessages, + persist: false, }); const lastAssistantResponse = last( diff --git a/x-pack/plugins/observability_ai_assistant/public/functions/alerts.ts b/x-pack/plugins/observability_ai_assistant/public/functions/alerts.ts deleted file mode 100644 index b4c0d0bd1bdfd1..00000000000000 --- a/x-pack/plugins/observability_ai_assistant/public/functions/alerts.ts +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -import type { RegisterFunctionDefinition } from '../../common/types'; -import type { ObservabilityAIAssistantService } from '../types'; - -const DEFAULT_FEATURE_IDS = [ - 'apm', - 'infrastructure', - 'logs', - 'uptime', - 'slo', - 'observability', -] as const; - -export function registerAlertsFunction({ - service, - registerFunction, -}: { - service: ObservabilityAIAssistantService; - registerFunction: RegisterFunctionDefinition; -}) { - registerFunction( - { - name: 'alerts', - contexts: ['core'], - description: - 'Get alerts for Observability. 
Display the response in tabular format if appropriate.', - descriptionForUser: 'Get alerts for Observability', - parameters: { - type: 'object', - additionalProperties: false, - properties: { - featureIds: { - type: 'array', - additionalItems: false, - items: { - type: 'string', - enum: DEFAULT_FEATURE_IDS, - }, - description: - 'The Observability apps for which to retrieve alerts. By default it will return alerts for all apps.', - }, - start: { - type: 'string', - description: 'The start of the time range, in Elasticsearch date math, like `now`.', - }, - end: { - type: 'string', - description: 'The end of the time range, in Elasticsearch date math, like `now-24h`.', - }, - filter: { - type: 'string', - description: - 'a KQL query to filter the data by. If no filter should be applied, leave it empty.', - }, - includeRecovered: { - type: 'boolean', - description: - 'Whether to include recovered/closed alerts. Defaults to false, which means only active alerts will be returned', - }, - }, - required: ['start', 'end'], - } as const, - }, - ({ arguments: { start, end, featureIds, filter, includeRecovered } }, signal) => { - return service.callApi('POST /internal/observability_ai_assistant/functions/alerts', { - params: { - body: { - start, - end, - featureIds: - featureIds && featureIds.length > 0 ? featureIds : DEFAULT_FEATURE_IDS.concat(), - filter, - includeRecovered, - }, - }, - signal, - }); - } - ); -} diff --git a/x-pack/plugins/observability_ai_assistant/public/functions/get_dataset_info.ts b/x-pack/plugins/observability_ai_assistant/public/functions/get_dataset_info.ts deleted file mode 100644 index cbb6167cf684ea..00000000000000 --- a/x-pack/plugins/observability_ai_assistant/public/functions/get_dataset_info.ts +++ /dev/null @@ -1,153 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. 
Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -import { chunk, groupBy, uniq } from 'lodash'; -import { CreateChatCompletionResponse } from 'openai'; -import { FunctionVisibility, MessageRole, RegisterFunctionDefinition } from '../../common/types'; -import type { ObservabilityAIAssistantService } from '../types'; - -export function registerGetDatasetInfoFunction({ - service, - registerFunction, -}: { - service: ObservabilityAIAssistantService; - registerFunction: RegisterFunctionDefinition; -}) { - registerFunction( - { - name: 'get_dataset_info', - contexts: ['core'], - visibility: FunctionVisibility.System, - description: `Use this function to get information about indices/datasets available and the fields available on them. - - providing empty string as index name will retrieve all indices - else list of all fields for the given index will be given. if no fields are returned this means no indices were matched by provided index pattern. 
- wildcards can be part of index name.`, - descriptionForUser: - 'This function allows the assistant to get information about available indices and their fields.', - parameters: { - type: 'object', - additionalProperties: false, - properties: { - index: { - type: 'string', - description: - 'index pattern the user is interested in or empty string to get information about all available indices', - }, - }, - required: ['index'], - } as const, - }, - async ({ arguments: { index }, messages, connectorId }, signal) => { - const response = await service.callApi( - 'POST /internal/observability_ai_assistant/functions/get_dataset_info', - { - params: { - body: { - index, - }, - }, - signal, - } - ); - - const allFields = response.fields; - - const fieldNames = uniq(allFields.map((field) => field.name)); - - const groupedFields = groupBy(allFields, (field) => field.name); - - const relevantFields = await Promise.all( - chunk(fieldNames, 500).map(async (fieldsInChunk) => { - const chunkResponse = (await service.callApi( - 'POST /internal/observability_ai_assistant/chat', - { - signal, - params: { - query: { - stream: false, - }, - body: { - connectorId, - messages: [ - { - '@timestamp': new Date().toISOString(), - message: { - role: MessageRole.System, - content: `You are a helpful assistant for Elastic Observability. - Your task is to create a list of field names that are relevant - to the conversation, using ONLY the list of fields and - types provided in the last user message. 
DO NOT UNDER ANY - CIRCUMSTANCES include fields not mentioned in this list.`, - }, - }, - ...messages.slice(1), - { - '@timestamp': new Date().toISOString(), - message: { - role: MessageRole.User, - content: `This is the list: - - ${fieldsInChunk.join('\n')}`, - }, - }, - ], - functions: [ - { - name: 'fields', - description: 'The fields you consider relevant to the conversation', - parameters: { - type: 'object', - additionalProperties: false, - properties: { - fields: { - type: 'array', - additionalProperties: false, - addditionalItems: false, - items: { - type: 'string', - additionalProperties: false, - addditionalItems: false, - }, - }, - }, - required: ['fields'], - }, - }, - ], - functionCall: 'fields', - }, - }, - } - )) as CreateChatCompletionResponse; - - return chunkResponse.choices[0].message?.function_call?.arguments - ? ( - JSON.parse(chunkResponse.choices[0].message?.function_call?.arguments) as { - fields: string[]; - } - ).fields - .filter((field) => fieldNames.includes(field)) - .map((field) => { - const fieldDescriptors = groupedFields[field]; - return `${field}:${fieldDescriptors - .map((descriptor) => descriptor.type) - .join(',')}`; - }) - : [chunkResponse.choices[0].message?.content ?? '']; - }) - ); - - return { - content: { - indices: response.indices, - fields: relevantFields.flat(), - }, - }; - } - ); -} diff --git a/x-pack/plugins/observability_ai_assistant/public/functions/index.ts b/x-pack/plugins/observability_ai_assistant/public/functions/index.ts index 97c311cfac0690..4bdbb31b15dcc6 100644 --- a/x-pack/plugins/observability_ai_assistant/public/functions/index.ts +++ b/x-pack/plugins/observability_ai_assistant/public/functions/index.ts @@ -5,88 +5,21 @@ * 2.0. 
*/ -import dedent from 'dedent'; -import type { CoreStart } from '@kbn/core/public'; -import type { RegisterContextDefinition, RegisterFunctionDefinition } from '../../common/types'; -import type { ObservabilityAIAssistantPluginStartDependencies } from '../types'; -import type { ObservabilityAIAssistantService } from '../types'; -import { registerElasticsearchFunction } from './elasticsearch'; -import { registerKibanaFunction } from './kibana'; -import { registerLensFunction } from './lens'; -import { registerRecallFunction } from './recall'; -import { registerGetDatasetInfoFunction } from './get_dataset_info'; -import { registerSummarizationFunction } from './summarize'; -import { registerAlertsFunction } from './alerts'; -import { registerEsqlFunction } from './esql'; +import type { + ObservabilityAIAssistantPluginStartDependencies, + ObservabilityAIAssistantService, + RegisterRenderFunctionDefinition, +} from '../types'; +import { registerLensRenderFunction } from './lens'; export async function registerFunctions({ - registerFunction, - registerContext, + registerRenderFunction, service, pluginsStart, - coreStart, - signal, }: { - registerFunction: RegisterFunctionDefinition; - registerContext: RegisterContextDefinition; + registerRenderFunction: RegisterRenderFunctionDefinition; service: ObservabilityAIAssistantService; pluginsStart: ObservabilityAIAssistantPluginStartDependencies; - coreStart: CoreStart; - signal: AbortSignal; }) { - return service - .callApi('GET /internal/observability_ai_assistant/functions/kb_status', { - signal, - }) - .then((response) => { - const isReady = response.ready; - - let description = dedent( - `You are a helpful assistant for Elastic Observability. Your goal is to help the Elastic Observability users to quickly assess what is happening in their observed systems. You can help them visualise and analyze data, investigate their systems, perform root cause analysis or identify optimisation opportunities. 
- - It's very important to not assume what the user is meaning. Ask them for clarification if needed. - - If you are unsure about which function should be used and with what arguments, ask the user for clarification or confirmation. - - In KQL, escaping happens with double quotes, not single quotes. Some characters that need escaping are: ':()\\\ - /\". Always put a field value in double quotes. Best: service.name:\"opbeans-go\". Wrong: service.name:opbeans-go. This is very important! - - You can use Github-flavored Markdown in your responses. If a function returns an array, consider using a Markdown table to format the response. - - If multiple functions are suitable, use the most specific and easy one. E.g., when the user asks to visualise APM data, use the APM functions (if available) rather than Lens. - - If a function call fails, DO NOT UNDER ANY CIRCUMSTANCES execute it again. Ask the user for guidance and offer them options. - - Note that ES|QL (the Elasticsearch query language, which is NOT Elasticsearch SQL, but a new piped language) is the preferred query language. - - If the user asks about a query, or ES|QL, always call the "esql" function. DO NOT UNDER ANY CIRCUMSTANCES generate ES|QL queries yourself. Even if the "recall" function was used before that, follow it up with the "esql" function.` - ); - - if (isReady) { - description += `You can use the "summarize" functions to store new information you have learned in a knowledge database. Once you have established that you did not know the answer to a question, and the user gave you this information, it's important that you create a summarisation of what you have learned and store it in the knowledge database. Don't create a new summarization if you see a similar summarization in the conversation, instead, update the existing one by re-using its ID. - - Additionally, you can use the "recall" function to retrieve relevant information from the knowledge database. 
- `; - - description += `Here are principles you MUST adhere to, in order: - - DO NOT make any assumptions about where and how users have stored their data. ALWAYS first call get_dataset_info function with empty string to get information about available indices. Once you know about available indices you MUST use this function again to get a list of available fields for specific index. If user provides an index name make sure its a valid index first before using it to retrieve the field list by calling this function with an empty string! - `; - registerSummarizationFunction({ service, registerFunction }); - registerRecallFunction({ service, registerFunction }); - registerLensFunction({ service, pluginsStart, registerFunction }); - } else { - description += `You do not have a working memory. Don't try to recall information via the "recall" function. If the user expects you to remember the previous conversations, tell them they can set up the knowledge base. A banner is available at the top of the conversation to set this up.`; - } - - registerElasticsearchFunction({ service, registerFunction }); - registerEsqlFunction({ service, registerFunction }); - registerKibanaFunction({ service, registerFunction, coreStart }); - registerAlertsFunction({ service, registerFunction }); - registerGetDatasetInfoFunction({ service, registerFunction }); - - registerContext({ - name: 'core', - description: dedent(description), - }); - }); + registerLensRenderFunction({ service, pluginsStart, registerRenderFunction }); } diff --git a/x-pack/plugins/observability_ai_assistant/public/functions/kibana.ts b/x-pack/plugins/observability_ai_assistant/public/functions/kibana.ts deleted file mode 100644 index a47acdb02d4338..00000000000000 --- a/x-pack/plugins/observability_ai_assistant/public/functions/kibana.ts +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. 
Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -import type { CoreStart } from '@kbn/core/public'; -import type { RegisterFunctionDefinition } from '../../common/types'; -import type { ObservabilityAIAssistantService } from '../types'; - -export function registerKibanaFunction({ - service, - registerFunction, - coreStart, -}: { - service: ObservabilityAIAssistantService; - registerFunction: RegisterFunctionDefinition; - coreStart: CoreStart; -}) { - registerFunction( - { - name: 'kibana', - contexts: ['core'], - description: - 'Call Kibana APIs on behalf of the user. Only call this function when the user has explicitly requested it, and you know how to call it, for example by querying the knowledge base or having the user explain it to you. Assume that pathnames, bodies and query parameters may have changed since your knowledge cut off date.', - descriptionForUser: 'Call Kibana APIs on behalf of the user', - parameters: { - type: 'object', - additionalProperties: false, - properties: { - method: { - type: 'string', - description: 'The HTTP method of the Kibana endpoint', - enum: ['GET', 'PUT', 'POST', 'DELETE', 'PATCH'] as const, - }, - pathname: { - type: 'string', - description: 'The pathname of the Kibana endpoint, excluding query parameters', - }, - query: { - type: 'object', - description: 'The query parameters, as an object', - additionalProperties: { - type: 'string', - }, - }, - body: { - type: 'object', - description: 'The body of the request', - }, - }, - required: ['method', 'pathname'] as const, - }, - }, - ({ arguments: { method, pathname, body, query } }, signal) => { - return coreStart.http - .fetch(pathname, { - method, - body: body ? 
JSON.stringify(body) : undefined, - query, - signal, - }) - .then((response) => { - return { content: response }; - }); - } - ); -} diff --git a/x-pack/plugins/observability_ai_assistant/public/functions/lens.tsx b/x-pack/plugins/observability_ai_assistant/public/functions/lens.tsx index fa8dc66dce2941..0d694a532a8056 100644 --- a/x-pack/plugins/observability_ai_assistant/public/functions/lens.tsx +++ b/x-pack/plugins/observability_ai_assistant/public/functions/lens.tsx @@ -6,17 +6,18 @@ */ import { EuiButton, EuiFlexGroup, EuiFlexItem, EuiLoadingSpinner } from '@elastic/eui'; import type { DataViewsServicePublic } from '@kbn/data-views-plugin/public/types'; -import { FIELD_FORMAT_IDS } from '@kbn/field-formats-plugin/common'; +import { i18n } from '@kbn/i18n'; import { LensAttributesBuilder, XYChart, XYDataLayer } from '@kbn/lens-embeddable-utils'; import type { LensEmbeddableInput, LensPublicStart } from '@kbn/lens-plugin/public'; import React, { useState } from 'react'; import useAsync from 'react-use/lib/useAsync'; -import { i18n } from '@kbn/i18n'; import { Assign } from 'utility-types'; -import type { RegisterFunctionDefinition } from '../../common/types'; +import { LensFunctionArguments } from '../../common/functions/lens'; import type { ObservabilityAIAssistantPluginStartDependencies, ObservabilityAIAssistantService, + RegisterRenderFunctionDefinition, + RenderFunction, } from '../types'; export enum SeriesType { @@ -137,120 +138,20 @@ function Lens({ ); } -export function registerLensFunction({ +export function registerLensRenderFunction({ service, - registerFunction, + registerRenderFunction, pluginsStart, }: { service: ObservabilityAIAssistantService; - registerFunction: RegisterFunctionDefinition; + registerRenderFunction: RegisterRenderFunctionDefinition; pluginsStart: ObservabilityAIAssistantPluginStartDependencies; }) { - registerFunction( - { - name: 'lens', - contexts: ['core'], - description: - "Use this function to create custom visualizations, 
using Lens, that can be saved to dashboards. This function does not return data to the assistant, it only shows it to the user. When using this function, make sure to use the recall function to get more information about how to use it, with how you want to use it. Make sure the query also contains information about the user's request. The visualisation is displayed to the user above your reply, DO NOT try to generate or display an image yourself.", - descriptionForUser: - 'Use this function to create custom visualizations, using Lens, that can be saved to dashboards.', - parameters: { - type: 'object', - additionalProperties: false, - properties: { - layers: { - type: 'array', - items: { - type: 'object', - additionalProperties: false, - properties: { - label: { - type: 'string', - }, - formula: { - type: 'string', - description: - 'The formula for calculating the value, e.g. sum(my_field_name). Query the knowledge base to get more information about the syntax and available formulas.', - }, - filter: { - type: 'string', - description: 'A KQL query that will be used as a filter for the series', - }, - format: { - type: 'object', - additionalProperties: false, - properties: { - id: { - type: 'string', - description: - 'How to format the value. When using duration, make sure the value is seconds OR is converted to seconds using math functions. Ask the user for clarification in which unit the value is stored, or derive it from the field name.', - enum: [ - FIELD_FORMAT_IDS.BYTES, - FIELD_FORMAT_IDS.CURRENCY, - FIELD_FORMAT_IDS.DURATION, - FIELD_FORMAT_IDS.NUMBER, - FIELD_FORMAT_IDS.PERCENT, - FIELD_FORMAT_IDS.STRING, - ], - }, - }, - required: ['id'], - }, - }, - required: ['label', 'formula', 'format'], - }, - }, - timeField: { - type: 'string', - default: '@timefield', - description: - 'time field to use for XY chart. 
Use @timefield if its available on the index.', - }, - breakdown: { - type: 'object', - additionalProperties: false, - properties: { - field: { - type: 'string', - }, - }, - required: ['field'], - }, - indexPattern: { - type: 'string', - }, - seriesType: { - type: 'string', - enum: [ - SeriesType.Area, - SeriesType.AreaPercentageStacked, - SeriesType.AreaStacked, - SeriesType.Bar, - SeriesType.BarHorizontal, - SeriesType.BarHorizontalPercentageStacked, - SeriesType.BarPercentageStacked, - SeriesType.BarStacked, - SeriesType.Line, - ], - }, - start: { - type: 'string', - description: 'The start of the time range, in Elasticsearch datemath', - }, - end: { - type: 'string', - description: 'The end of the time range, in Elasticsearch datemath', - }, - }, - required: ['layers', 'indexPattern', 'start', 'end', 'timeField'], - } as const, - }, - async () => { - return { - content: {}, - }; - }, - ({ arguments: { layers, indexPattern, breakdown, seriesType, start, end, timeField } }) => { + registerRenderFunction( + 'lens', + ({ + arguments: { layers, indexPattern, breakdown, seriesType, start, end, timeField }, + }: Parameters>[0]) => { const xyDataLayer = new XYDataLayer({ data: layers.map((layer) => ({ type: 'formula', diff --git a/x-pack/plugins/observability_ai_assistant/public/hooks/use_chat.test.ts b/x-pack/plugins/observability_ai_assistant/public/hooks/use_chat.test.ts index 22bc997a4c925b..1c77be78d591b4 100644 --- a/x-pack/plugins/observability_ai_assistant/public/hooks/use_chat.test.ts +++ b/x-pack/plugins/observability_ai_assistant/public/hooks/use_chat.test.ts @@ -8,7 +8,13 @@ import type { DeeplyMockedKeys } from '@kbn/utility-types-jest'; import { type RenderHookResult, renderHook, act } from '@testing-library/react-hooks'; import { Subject } from 'rxjs'; import { MessageRole } from '../../common'; -import type { ObservabilityAIAssistantChatService, PendingMessage } from '../types'; +import { + ChatCompletionErrorCode, + ConversationCompletionError, + 
StreamingChatResponseEvent, + StreamingChatResponseEventType, +} from '../../common/conversation_complete'; +import type { ObservabilityAIAssistantChatService } from '../types'; import { type UseChatResult, useChat, type UseChatProps, ChatState } from './use_chat'; import * as useKibanaModule from './use_kibana'; @@ -16,7 +22,7 @@ type MockedChatService = DeeplyMockedKeys; const mockChatService: MockedChatService = { chat: jest.fn(), - executeFunction: jest.fn(), + complete: jest.fn(), getContexts: jest.fn().mockReturnValue([{ name: 'core', description: '' }]), getFunctions: jest.fn().mockReturnValue([]), hasFunction: jest.fn().mockReturnValue(false), @@ -58,6 +64,7 @@ describe('useChat', () => { }, }, ], + persist: false, } as UseChatProps, }); }); @@ -75,7 +82,7 @@ describe('useChat', () => { }); describe('when calling next()', () => { - let subject: Subject; + let subject: Subject; beforeEach(() => { hookResult = renderHook(useChat, { @@ -83,12 +90,13 @@ describe('useChat', () => { connectorId: 'my-connector', chatService: mockChatService, initialMessages: [], + persist: false, } as UseChatProps, }); subject = new Subject(); - mockChatService.chat.mockReturnValueOnce(subject); + mockChatService.complete.mockReturnValueOnce(subject); act(() => { hookResult.result.current.next([ @@ -113,11 +121,23 @@ describe('useChat', () => { act(() => { hookResult.result.current.next([]); subject.next({ + type: StreamingChatResponseEventType.ChatCompletionChunk, + id: 'my-message-id', message: { - role: MessageRole.User, content: 'goodbye', }, }); + subject.next({ + type: StreamingChatResponseEventType.MessageAdd, + id: 'my-message-id', + message: { + '@timestamp': new Date().toISOString(), + message: { + content: 'goodbye', + role: MessageRole.Assistant, + }, + }, + }); subject.complete(); }); }); @@ -136,9 +156,10 @@ describe('useChat', () => { it('updates the returned messages', () => { act(() => { subject.next({ + type: StreamingChatResponseEventType.ChatCompletionChunk, + 
id: 'my-message-id', message: { content: 'good', - role: MessageRole.Assistant, }, }); }); @@ -151,15 +172,28 @@ describe('useChat', () => { it('updates the returned messages and the loading state', () => { act(() => { subject.next({ + type: StreamingChatResponseEventType.ChatCompletionChunk, + id: 'my-message-id', message: { content: 'good', - role: MessageRole.Assistant, }, }); subject.next({ + type: StreamingChatResponseEventType.ChatCompletionChunk, + id: 'my-message-id', message: { - content: 'goodbye', - role: MessageRole.Assistant, + content: 'bye', + }, + }); + subject.next({ + type: StreamingChatResponseEventType.MessageAdd, + id: 'my-message-id', + message: { + '@timestamp': new Date().toISOString(), + message: { + content: 'goodbye', + role: MessageRole.Assistant, + }, }, }); subject.complete(); @@ -174,13 +208,13 @@ describe('useChat', () => { beforeEach(() => { act(() => { subject.next({ + type: StreamingChatResponseEventType.ChatCompletionChunk, + id: 'my-message-id', message: { content: 'good', - role: MessageRole.Assistant, }, - aborted: true, }); - subject.complete(); + hookResult.result.current.stop(); }); }); @@ -198,13 +232,15 @@ describe('useChat', () => { beforeEach(() => { act(() => { subject.next({ + type: StreamingChatResponseEventType.ChatCompletionChunk, + id: 'my-message-id', message: { content: 'good', - role: MessageRole.Assistant, }, - error: new Error('foo'), }); - subject.complete(); + subject.error( + new ConversationCompletionError(ChatCompletionErrorCode.InternalError, 'foo') + ); }); }); @@ -217,248 +253,5 @@ describe('useChat', () => { expect(addErrorMock).toHaveBeenCalled(); }); }); - - describe('after the LLM responds with a function call', () => { - let resolve: (data: any) => void; - let reject: (error: Error) => void; - - beforeEach(() => { - mockChatService.executeFunction.mockResolvedValueOnce( - new Promise((...args) => { - resolve = args[0]; - reject = args[1]; - }) - ); - - act(() => { - subject.next({ - message: { - 
content: '', - role: MessageRole.Assistant, - function_call: { - name: 'my_function', - arguments: JSON.stringify({ foo: 'bar' }), - trigger: MessageRole.Assistant, - }, - }, - }); - subject.complete(); - }); - }); - - it('the chat state stays loading', () => { - expect(hookResult.result.current.state).toBe(ChatState.Loading); - }); - - it('adds a message', () => { - const { messages } = hookResult.result.current; - - expect(messages.length).toBe(3); - expect(messages[2]).toEqual({ - '@timestamp': expect.any(String), - message: { - content: '', - function_call: { - arguments: JSON.stringify({ foo: 'bar' }), - name: 'my_function', - trigger: MessageRole.Assistant, - }, - role: MessageRole.Assistant, - }, - }); - }); - - describe('the function call succeeds', () => { - beforeEach(async () => { - subject = new Subject(); - mockChatService.chat.mockReturnValueOnce(subject); - - await act(async () => { - resolve({ content: { foo: 'bar' }, data: { bar: 'foo' } }); - }); - }); - - it('adds a message', () => { - const { messages } = hookResult.result.current; - - expect(messages.length).toBe(4); - expect(messages[3]).toEqual({ - '@timestamp': expect.any(String), - message: { - content: JSON.stringify({ foo: 'bar' }), - data: JSON.stringify({ bar: 'foo' }), - name: 'my_function', - role: MessageRole.User, - }, - }); - }); - - it('keeps the chat state in loading', () => { - expect(hookResult.result.current.state).toBe(ChatState.Loading); - }); - it('sends the function call back to the LLM for a response', () => { - expect(mockChatService.chat).toHaveBeenCalledTimes(2); - expect(mockChatService.chat).toHaveBeenLastCalledWith({ - connectorId: 'my-connector', - messages: hookResult.result.current.messages, - }); - }); - }); - - describe('the function call fails', () => { - beforeEach(async () => { - subject = new Subject(); - mockChatService.chat.mockReturnValue(subject); - - await act(async () => { - reject(new Error('connection error')); - }); - }); - - it('keeps the chat 
state in loading', () => { - expect(hookResult.result.current.state).toBe(ChatState.Loading); - }); - - it('adds a message', () => { - const { messages } = hookResult.result.current; - - expect(messages.length).toBe(4); - expect(messages[3]).toEqual({ - '@timestamp': expect.any(String), - message: { - content: JSON.stringify({ - message: 'Error: connection error', - error: {}, - }), - name: 'my_function', - role: MessageRole.User, - }, - }); - }); - - it('does not show an error toast', () => { - expect(addErrorMock).not.toHaveBeenCalled(); - }); - - it('sends the function call back to the LLM for a response', () => { - expect(mockChatService.chat).toHaveBeenCalledTimes(2); - expect(mockChatService.chat).toHaveBeenLastCalledWith({ - connectorId: 'my-connector', - messages: hookResult.result.current.messages, - }); - }); - }); - - describe('stop() is called', () => { - beforeEach(() => { - act(() => { - hookResult.result.current.stop(); - }); - }); - - it('sets the chatState to aborted', () => { - expect(hookResult.result.current.state).toBe(ChatState.Aborted); - }); - - it('has called the abort controller', () => { - const signal = mockChatService.executeFunction.mock.calls[0][0].signal; - - expect(signal.aborted).toBe(true); - }); - - it('is not updated after the promise is rejected', () => { - const numRenders = hookResult.result.all.length; - - act(() => { - reject(new Error('Request aborted')); - }); - - expect(numRenders).toBe(hookResult.result.all.length); - }); - - it('removes all subscribers', () => { - expect(subject.observed).toBe(false); - }); - }); - - describe('setMessages() is called', () => {}); - }); - }); - - describe('when calling next() with the recall function available', () => { - let subject: Subject; - - beforeEach(async () => { - hookResult = renderHook(useChat, { - initialProps: { - connectorId: 'my-connector', - chatService: mockChatService, - initialMessages: [], - } as UseChatProps, - }); - - subject = new Subject(); - - 
mockChatService.hasFunction.mockReturnValue(true); - mockChatService.executeFunction.mockResolvedValueOnce({ - content: [ - { - id: 'my_document', - text: 'My text', - }, - ], - }); - - mockChatService.chat.mockReturnValueOnce(subject); - - await act(async () => { - hookResult.result.current.next([ - ...hookResult.result.current.messages, - { - '@timestamp': new Date().toISOString(), - message: { - role: MessageRole.User, - content: 'hello', - }, - }, - ]); - }); - }); - - it('adds a user message and a recall function request', () => { - expect(hookResult.result.current.messages[1].message.content).toBe('hello'); - expect(hookResult.result.current.messages[2].message.function_call?.name).toBe('recall'); - expect(hookResult.result.current.messages[2].message.content).toBe(''); - expect(hookResult.result.current.messages[2].message.function_call?.arguments).toBe( - JSON.stringify({ queries: [], contexts: [] }) - ); - expect(hookResult.result.current.messages[3].message.name).toBe('recall'); - expect(hookResult.result.current.messages[3].message.content).toBe( - JSON.stringify([ - { - id: 'my_document', - text: 'My text', - }, - ]) - ); - }); - - it('executes the recall function', () => { - expect(mockChatService.executeFunction).toHaveBeenCalled(); - expect(mockChatService.executeFunction).toHaveBeenCalledWith({ - signal: expect.any(AbortSignal), - connectorId: 'my-connector', - args: JSON.stringify({ queries: [], contexts: [] }), - name: 'recall', - messages: [...hookResult.result.current.messages.slice(0, -1)], - }); - }); - - it('sends the user message, function request and recall response to the LLM', () => { - expect(mockChatService.chat).toHaveBeenCalledWith({ - connectorId: 'my-connector', - messages: [...hookResult.result.current.messages], - }); - }); }); }); diff --git a/x-pack/plugins/observability_ai_assistant/public/hooks/use_chat.ts b/x-pack/plugins/observability_ai_assistant/public/hooks/use_chat.ts index aeef36127f6c4b..989b3fdcb23a84 100644 --- 
a/x-pack/plugins/observability_ai_assistant/public/hooks/use_chat.ts +++ b/x-pack/plugins/observability_ai_assistant/public/hooks/use_chat.ts @@ -6,12 +6,16 @@ */ import { i18n } from '@kbn/i18n'; -import { last } from 'lodash'; +import { merge } from 'lodash'; import { useCallback, useEffect, useMemo, useRef, useState } from 'react'; -import { isObservable } from 'rxjs'; -import { type Message, MessageRole } from '../../common'; +import { MessageRole, type Message } from '../../common'; +import { + ConversationCreateEvent, + ConversationUpdateEvent, + StreamingChatResponseEventType, +} from '../../common/conversation_complete'; import { getAssistantSetupMessage } from '../service/get_assistant_setup_message'; -import type { ObservabilityAIAssistantChatService, PendingMessage } from '../types'; +import type { ObservabilityAIAssistantChatService } from '../types'; import { useKibana } from './use_kibana'; import { useOnce } from './use_once'; @@ -22,6 +26,13 @@ export enum ChatState { Aborted = 'aborted', } +function getWithSystemMessage(messages: Message[], systemMessage: Message) { + return [ + systemMessage, + ...messages.filter((message) => message.message.role !== MessageRole.System), + ]; +} + export interface UseChatResult { messages: Message[]; setMessages: (messages: Message[]) => void; @@ -32,16 +43,22 @@ export interface UseChatResult { export interface UseChatProps { initialMessages: Message[]; + initialConversationId?: string; chatService: ObservabilityAIAssistantChatService; connectorId?: string; + persist: boolean; + onConversationUpdate?: (event: ConversationCreateEvent | ConversationUpdateEvent) => void; onChatComplete?: (messages: Message[]) => void; } export function useChat({ initialMessages, + initialConversationId: initialConversationIdFromProps, chatService, connectorId, + onConversationUpdate, onChatComplete, + persist, }: UseChatProps): UseChatResult { const [chatState, setChatState] = useState(ChatState.Ready); @@ -51,9 +68,11 @@ export 
function useChat({ useOnce(initialMessages); + const initialConversationId = useOnce(initialConversationIdFromProps); + const [messages, setMessages] = useState(initialMessages); - const [pendingMessage, setPendingMessage] = useState(); + const [pendingMessages, setPendingMessages] = useState(); const abortControllerRef = useRef(new AbortController()); @@ -62,13 +81,27 @@ export function useChat({ } = useKibana(); const onChatCompleteRef = useRef(onChatComplete); - onChatCompleteRef.current = onChatComplete; + const onConversationUpdateRef = useRef(onConversationUpdate); + onConversationUpdateRef.current = onConversationUpdate; + const handleSignalAbort = useCallback(() => { setChatState(ChatState.Aborted); }, []); + const handleError = useCallback( + (error: Error) => { + notifications.toasts.addError(error, { + title: i18n.translate('xpack.observabilityAiAssistant.failedToLoadResponse', { + defaultMessage: 'Failed to load response from the AI Assistant', + }), + }); + setChatState(ChatState.Error); + }, + [notifications.toasts] + ); + const next = useCallback( async (nextMessages: Message[]) => { // make sure we ignore any aborts for the previous signal @@ -77,173 +110,134 @@ export function useChat({ // cancel running requests abortControllerRef.current.abort(); - const lastMessage = last(nextMessages); + abortControllerRef.current = new AbortController(); - const allMessages = [ - systemMessage, - ...nextMessages.filter((message) => message.message.role !== MessageRole.System), - ]; + setPendingMessages([]); + setMessages(nextMessages); - setMessages(allMessages); - - if (!lastMessage || !connectorId) { + if (!connectorId || !nextMessages.length) { setChatState(ChatState.Ready); - onChatCompleteRef.current?.(nextMessages); return; } - const isUserMessage = lastMessage.message.role === MessageRole.User; - const functionCall = lastMessage.message.function_call; - const isAssistantMessageWithFunctionRequest = - lastMessage.message.role === MessageRole.Assistant && 
functionCall && !!functionCall.name; - - const isFunctionResult = isUserMessage && !!lastMessage.message.name; - - const isRecallFunctionAvailable = chatService.hasFunction('recall'); - - if (!isUserMessage && !isAssistantMessageWithFunctionRequest) { - setChatState(ChatState.Ready); - onChatCompleteRef.current?.(nextMessages); - return; - } - - const abortController = (abortControllerRef.current = new AbortController()); - - abortController.signal.addEventListener('abort', handleSignalAbort); - setChatState(ChatState.Loading); - if (isUserMessage && !isFunctionResult && isRecallFunctionAvailable) { - const allMessagesWithRecall = allMessages.concat({ - '@timestamp': new Date().toISOString(), - message: { - role: MessageRole.Assistant, - content: '', - function_call: { - name: 'recall', - arguments: JSON.stringify({ queries: [], contexts: [] }), - trigger: MessageRole.Assistant, - }, - }, - }); - next(allMessagesWithRecall); - return; - } - - function handleError(error: Error) { - setChatState(ChatState.Error); - notifications.toasts.addError(error, { - title: i18n.translate('xpack.observabilityAiAssistant.failedToLoadResponse', { - defaultMessage: 'Failed to load response from the AI Assistant', - }), - }); - } - - const response = isAssistantMessageWithFunctionRequest - ? 
await chatService - .executeFunction({ - name: functionCall.name, - signal: abortController.signal, - args: functionCall.arguments, - connectorId, - messages: allMessages, - }) - .catch((error) => { - return { - content: { - message: error.toString(), - error, - }, - data: undefined, - }; - }) - : chatService.chat({ - messages: allMessages, - connectorId, - }); - - if (abortController.signal.aborted) { - return; + const next$ = chatService.complete({ + connectorId, + messages: getWithSystemMessage(nextMessages, systemMessage), + persist, + signal: abortControllerRef.current.signal, + conversationId: initialConversationId, + }); + + function getPendingMessages() { + return [ + ...completedMessages, + ...(pendingMessage + ? [ + merge( + { + message: { + role: MessageRole.Assistant, + function_call: { trigger: MessageRole.Assistant as const }, + }, + }, + pendingMessage + ), + ] + : []), + ]; } - if (isObservable(response)) { - let localPendingMessage: PendingMessage = { - message: { - content: '', - role: MessageRole.User, - }, - }; - - const subscription = response.subscribe({ - next: (nextPendingMessage) => { - localPendingMessage = nextPendingMessage; - setPendingMessage(nextPendingMessage); - }, - complete: () => { - setPendingMessage(undefined); - const allMessagesWithResolved = allMessages.concat({ - message: { - ...localPendingMessage.message, - }, - '@timestamp': new Date().toISOString(), - }); - if (localPendingMessage.aborted) { - setChatState(ChatState.Aborted); - setMessages(allMessagesWithResolved); - } else if (localPendingMessage.error) { - handleError(localPendingMessage.error); - setMessages(allMessagesWithResolved); - } else { - next(allMessagesWithResolved); - } - }, - error: (error) => { - handleError(error); - }, - }); - - abortController.signal.addEventListener('abort', () => { - subscription.unsubscribe(); - }); - } else { - const allMessagesWithFunctionReply = allMessages.concat({ - '@timestamp': new Date().toISOString(), - message: { - name: 
functionCall!.name, - role: MessageRole.User, - content: JSON.stringify(response.content), - data: JSON.stringify(response.data), - }, - }); - next(allMessagesWithFunctionReply); - } + const completedMessages: Message[] = []; + + let pendingMessage: + | { + '@timestamp': string; + message: { content: string; function_call: { name: string; arguments: string } }; + } + | undefined; + + const subscription = next$.subscribe({ + next: (event) => { + switch (event.type) { + case StreamingChatResponseEventType.ChatCompletionChunk: + if (!pendingMessage) { + pendingMessage = { + '@timestamp': new Date().toISOString(), + message: { + content: event.message.content || '', + function_call: { + name: event.message.function_call?.name || '', + arguments: event.message.function_call?.arguments || '', + }, + }, + }; + } else { + pendingMessage.message.content += event.message.content || ''; + pendingMessage.message.function_call.name += + event.message.function_call?.name || ''; + pendingMessage.message.function_call.arguments += + event.message.function_call?.arguments || ''; + } + break; + + case StreamingChatResponseEventType.MessageAdd: + pendingMessage = undefined; + completedMessages.push(event.message); + break; + + case StreamingChatResponseEventType.ConversationCreate: + case StreamingChatResponseEventType.ConversationUpdate: + onConversationUpdateRef.current?.(event); + break; + } + setPendingMessages(getPendingMessages()); + }, + complete: () => { + setChatState(ChatState.Ready); + const completed = nextMessages.concat(completedMessages); + setMessages(completed); + setPendingMessages([]); + onChatCompleteRef.current?.(completed); + }, + error: (error) => { + setPendingMessages([]); + setMessages(nextMessages.concat(getPendingMessages())); + handleError(error); + }, + }); + + abortControllerRef.current.signal.addEventListener('abort', () => { + handleSignalAbort(); + subscription.unsubscribe(); + }); }, - [connectorId, chatService, handleSignalAbort, 
notifications.toasts, systemMessage] + [ + connectorId, + chatService, + handleSignalAbort, + systemMessage, + handleError, + persist, + initialConversationId, + ] ); useEffect(() => { + const controller = abortControllerRef.current; return () => { - abortControllerRef.current.abort(); + controller.abort(); }; }, []); const memoizedMessages = useMemo(() => { - const includingSystemMessage = [ - systemMessage, - ...messages.filter((message) => message.message.role !== MessageRole.System), - ]; - - return pendingMessage - ? includingSystemMessage.concat({ - ...pendingMessage, - '@timestamp': new Date().toISOString(), - }) - : includingSystemMessage; - }, [systemMessage, messages, pendingMessage]); + return getWithSystemMessage(messages.concat(pendingMessages ?? []), systemMessage); + }, [systemMessage, messages, pendingMessages]); const setMessagesWithAbort = useCallback((nextMessages: Message[]) => { abortControllerRef.current.abort(); - setPendingMessage(undefined); + setPendingMessages([]); setChatState(ChatState.Ready); setMessages(nextMessages); }, []); diff --git a/x-pack/plugins/observability_ai_assistant/public/hooks/use_conversation.test.tsx b/x-pack/plugins/observability_ai_assistant/public/hooks/use_conversation.test.tsx index 11d5f322ddbcfa..2bf6f1910c42af 100644 --- a/x-pack/plugins/observability_ai_assistant/public/hooks/use_conversation.test.tsx +++ b/x-pack/plugins/observability_ai_assistant/public/hooks/use_conversation.test.tsx @@ -4,28 +4,32 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. 
*/ -import React from 'react'; -import { - useConversation, - type UseConversationProps, - type UseConversationResult, -} from './use_conversation'; +import type { DeeplyMockedKeys } from '@kbn/utility-types-jest'; import { act, renderHook, type RenderHookResult, type WrapperComponent, } from '@testing-library/react-hooks'; -import type { ObservabilityAIAssistantService, PendingMessage } from '../types'; -import type { DeeplyMockedKeys } from '@kbn/utility-types-jest'; -import { ObservabilityAIAssistantProvider } from '../context/observability_ai_assistant_provider'; -import * as useKibanaModule from './use_kibana'; -import { Message, MessageRole } from '../../common'; -import { ChatState } from './use_chat'; -import { createMockChatService } from '../service/create_mock_chat_service'; +import { merge } from 'lodash'; +import React from 'react'; import { Subject } from 'rxjs'; +import { MessageRole } from '../../common'; +import { + StreamingChatResponseEvent, + StreamingChatResponseEventType, +} from '../../common/conversation_complete'; +import { ObservabilityAIAssistantProvider } from '../context/observability_ai_assistant_provider'; import { EMPTY_CONVERSATION_TITLE } from '../i18n'; -import { merge, omit } from 'lodash'; +import { createMockChatService } from '../service/create_mock_chat_service'; +import type { ObservabilityAIAssistantService } from '../types'; +import { ChatState } from './use_chat'; +import { + useConversation, + type UseConversationProps, + type UseConversationResult, +} from './use_conversation'; +import * as useKibanaModule from './use_kibana'; let hookResult: RenderHookResult; @@ -269,8 +273,9 @@ describe('useConversation', () => { }); }); - describe('when chat completes without an initial conversation id', () => { - const subject: Subject = new Subject(); + describe('when chat completes', () => { + const subject: Subject = new Subject(); + let onConversationUpdate: jest.Mock; const expectedMessages = [ { '@timestamp': 
expect.any(String), @@ -321,6 +326,8 @@ describe('useConversation', () => { ) ); + onConversationUpdate = jest.fn(); + hookResult = renderHook(useConversation, { initialProps: { chatService: mockChatService, @@ -341,197 +348,66 @@ describe('useConversation', () => { }, }, ], + onConversationUpdate, }, wrapper, }); - mockChatService.chat.mockImplementationOnce(() => { + mockChatService.complete.mockImplementationOnce(() => { return subject; }); }); - it('the conversation is created including the initial messages', async () => { - act(() => { - hookResult.result.current.next( - hookResult.result.current.messages.concat({ - '@timestamp': new Date().toISOString(), + describe('and the conversation is created or updated', () => { + beforeEach(async () => { + await act(async () => { + hookResult.result.current.next( + hookResult.result.current.messages.concat({ + '@timestamp': new Date().toISOString(), + message: { + content: 'Hello again', + role: MessageRole.User, + }, + }) + ); + subject.next({ + type: StreamingChatResponseEventType.ChatCompletionChunk, + id: 'my-message', message: { - role: MessageRole.User, - content: 'Hello again', + content: 'Goodbye', }, - }) - ); - subject.next({ - message: { - role: MessageRole.Assistant, - content: 'Goodbye again', - }, - }); - subject.complete(); - }); - - await act(async () => {}); - - expect(mockService.callApi.mock.calls[0]).toEqual([ - 'POST /internal/observability_ai_assistant/conversation', - { - params: { - body: { - conversation: { - '@timestamp': expect.any(String), - conversation: { - title: EMPTY_CONVERSATION_TITLE, - }, - messages: expectedMessages, - labels: {}, - numeric_labels: {}, - public: false, + }); + subject.next({ + type: StreamingChatResponseEventType.MessageAdd, + id: 'my-message', + message: { + '@timestamp': new Date().toISOString(), + message: { + content: 'Goodbye', + role: MessageRole.Assistant, }, }, - }, - signal: null, - }, - ]); - - 
expect(hookResult.result.current.conversation.error).toBeUndefined(); - - expect(hookResult.result.current.messages).toEqual(expectedMessages); - }); - }); - - describe('when chat completes with an initial conversation id', () => { - let subject: Subject; - - const initialMessages: Message[] = [ - { - '@timestamp': new Date().toISOString(), - message: { - role: MessageRole.System, - content: '', - }, - }, - { - '@timestamp': new Date().toISOString(), - message: { - role: MessageRole.User, - content: 'user', - }, - }, - { - '@timestamp': new Date().toISOString(), - message: { - role: MessageRole.Assistant, - content: 'assistant', - }, - }, - ]; - - beforeEach(async () => { - mockService.callApi.mockImplementation(async (endpoint, request) => ({ - '@timestamp': new Date().toISOString(), - conversation: { - id: 'my-conversation-id', - title: EMPTY_CONVERSATION_TITLE, - }, - labels: {}, - numeric_labels: {}, - public: false, - messages: initialMessages, - })); - - hookResult = renderHook(useConversation, { - initialProps: { - chatService: mockChatService, - connectorId: 'my-connector', - initialConversationId: 'my-conversation-id', - }, - wrapper, - }); - - await act(async () => {}); - }); - - it('the conversation is loadeded', async () => { - expect(mockService.callApi.mock.calls[0]).toEqual([ - 'GET /internal/observability_ai_assistant/conversation/{conversationId}', - { - signal: expect.anything(), - params: { - path: { - conversationId: 'my-conversation-id', + }); + subject.next({ + type: StreamingChatResponseEventType.ConversationUpdate, + conversation: { + id: 'my-conversation-id', + title: 'My title', + last_updated: new Date().toISOString(), }, - }, - }, - ]); - - expect(hookResult.result.current.messages).toEqual( - initialMessages.map((msg) => ({ ...msg, '@timestamp': expect.any(String) })) - ); - }); - - describe('after chat completes', () => { - const nextUserMessage: Message = { - '@timestamp': new Date().toISOString(), - message: { - role: 
MessageRole.User, - content: 'Hello again', - }, - }; - - const nextAssistantMessage: Message = { - '@timestamp': new Date().toISOString(), - message: { - role: MessageRole.Assistant, - content: 'Goodbye again', - }, - }; - - beforeEach(async () => { - mockService.callApi.mockClear(); - subject = new Subject(); - - mockChatService.chat.mockImplementationOnce(() => { - return subject; - }); - - act(() => { - hookResult.result.current.next( - hookResult.result.current.messages.concat(nextUserMessage) - ); - subject.next(omit(nextAssistantMessage, '@timestamp')); + }); subject.complete(); }); - - await act(async () => {}); }); - it('saves the updated message', () => { - expect(mockService.callApi.mock.calls[0]).toEqual([ - 'PUT /internal/observability_ai_assistant/conversation/{conversationId}', - { - params: { - path: { - conversationId: 'my-conversation-id', - }, - body: { - conversation: { - '@timestamp': expect.any(String), - conversation: { - title: EMPTY_CONVERSATION_TITLE, - id: 'my-conversation-id', - }, - messages: initialMessages - .concat([nextUserMessage, nextAssistantMessage]) - .map((msg) => ({ ...msg, '@timestamp': expect.any(String) })), - labels: {}, - numeric_labels: {}, - public: false, - }, - }, - }, - signal: null, + it('calls the onConversationUpdate hook', () => { + expect(onConversationUpdate).toHaveBeenCalledWith({ + conversation: { + id: 'my-conversation-id', + last_updated: expect.any(String), + title: 'My title', }, - ]); + }); }); }); }); diff --git a/x-pack/plugins/observability_ai_assistant/public/hooks/use_conversation.ts b/x-pack/plugins/observability_ai_assistant/public/hooks/use_conversation.ts index c753f7c7b19292..e9d5b3f8073e43 100644 --- a/x-pack/plugins/observability_ai_assistant/public/hooks/use_conversation.ts +++ b/x-pack/plugins/observability_ai_assistant/public/hooks/use_conversation.ts @@ -38,7 +38,7 @@ export interface UseConversationProps { initialTitle?: string; chatService: ObservabilityAIAssistantChatService; 
connectorId: string | undefined; - onConversationUpdate?: (conversation: Conversation) => void; + onConversationUpdate?: (conversation: { conversation: Conversation['conversation'] }) => void; } export type UseConversationResult = { @@ -101,76 +101,16 @@ export function useConversation({ }); }; - const save = (nextMessages: Message[]) => { - const conversationObject = conversation.value!; - - const nextConversationObject = merge({}, omit(conversationObject, 'messages'), { - messages: nextMessages, - }); - - return ( - displayedConversationId - ? update( - merge( - { conversation: { id: displayedConversationId } }, - nextConversationObject - ) as Conversation - ) - : service - .callApi(`POST /internal/observability_ai_assistant/conversation`, { - signal: null, - params: { - body: { - conversation: nextConversationObject, - }, - }, - }) - .then((nextConversation) => { - setDisplayedConversationId(nextConversation.conversation.id); - if (connectorId) { - service - .callApi( - `PUT /internal/observability_ai_assistant/conversation/{conversationId}/auto_title`, - { - signal: null, - params: { - path: { - conversationId: nextConversation.conversation.id, - }, - body: { - connectorId, - }, - }, - } - ) - .then(() => { - onConversationUpdate?.(nextConversation); - return conversation.refresh(); - }); - } - return nextConversation; - }) - .catch((err) => { - notifications.toasts.addError(err, { - title: i18n.translate('xpack.observabilityAiAssistant.errorCreatingConversation', { - defaultMessage: 'Could not create conversation', - }), - }); - throw err; - }) - ).then((nextConversation) => { - onConversationUpdate?.(nextConversation); - return nextConversation; - }); - }; - const { next, messages, setMessages, state, stop } = useChat({ initialMessages, + initialConversationId, chatService, connectorId, - onChatComplete: (nextMessages) => { - save(nextMessages); + onConversationUpdate: (event) => { + setDisplayedConversationId(event.conversation.id); + 
onConversationUpdate?.({ conversation: event.conversation }); }, + persist: true, }); const [displayedConversationId, setDisplayedConversationId] = useState(initialConversationId); diff --git a/x-pack/plugins/observability_ai_assistant/public/hooks/use_json_editor_model.ts b/x-pack/plugins/observability_ai_assistant/public/hooks/use_json_editor_model.ts index e6d1f81dabd2ff..44122614abd5da 100644 --- a/x-pack/plugins/observability_ai_assistant/public/hooks/use_json_editor_model.ts +++ b/x-pack/plugins/observability_ai_assistant/public/hooks/use_json_editor_model.ts @@ -23,21 +23,19 @@ export const useJsonEditorModel = ({ }) => { const chatService = useObservabilityAIAssistantChatService(); - const functionDefinition = chatService - .getFunctions() - .find((func) => func.options.name === functionName); + const functionDefinition = chatService.getFunctions().find((func) => func.name === functionName); return useMemo(() => { if (!functionDefinition) { return {}; } - const schema = { ...functionDefinition.options.parameters }; + const schema = { ...functionDefinition.parameters }; const initialJsonString = initialJson ? initialJson - : functionDefinition.options.parameters.properties - ? JSON.stringify(createInitializedObject(functionDefinition.options.parameters), null, 4) + : functionDefinition.parameters.properties + ? 
JSON.stringify(createInitializedObject(functionDefinition.parameters), null, 4) : ''; languages.json.jsonDefaults.setDiagnosticsOptions({ diff --git a/x-pack/plugins/observability_ai_assistant/public/plugin.tsx b/x-pack/plugins/observability_ai_assistant/public/plugin.tsx index 6e9eb6dc731969..df8893b1f6252c 100644 --- a/x-pack/plugins/observability_ai_assistant/public/plugin.tsx +++ b/x-pack/plugins/observability_ai_assistant/public/plugin.tsx @@ -102,16 +102,13 @@ export class ObservabilityAIAssistantPlugin enabled: coreStart.application.capabilities.observabilityAIAssistant.show === true, })); - service.register(async ({ signal, registerContext, registerFunction }) => { + service.register(async ({ registerRenderFunction }) => { const mod = await import('./functions'); return mod.registerFunctions({ service, - signal, pluginsStart, - coreStart, - registerContext, - registerFunction, + registerRenderFunction, }); }); diff --git a/x-pack/plugins/observability_ai_assistant/public/service/create_chat_service.test.ts b/x-pack/plugins/observability_ai_assistant/public/service/create_chat_service.test.ts index ba7f5e2216bc8d..472499ae4d5e73 100644 --- a/x-pack/plugins/observability_ai_assistant/public/service/create_chat_service.test.ts +++ b/x-pack/plugins/observability_ai_assistant/public/service/create_chat_service.test.ts @@ -39,6 +39,12 @@ describe('createChatService', () => { } beforeEach(async () => { + clientSpy.mockImplementationOnce(async () => { + return { + functionDefinitions: [], + contextDefinitions: [], + }; + }); service = await createChatService({ client: clientSpy, registrations: [], diff --git a/x-pack/plugins/observability_ai_assistant/public/service/create_chat_service.ts b/x-pack/plugins/observability_ai_assistant/public/service/create_chat_service.ts index 8dc61e6d48449c..2fb1a47d594ff5 100644 --- a/x-pack/plugins/observability_ai_assistant/public/service/create_chat_service.ts +++ 
b/x-pack/plugins/observability_ai_assistant/public/service/create_chat_service.ts @@ -4,8 +4,6 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ -/* eslint-disable max-classes-per-file*/ -import { Validator, type Schema, type OutputUnit } from '@cfworker/json-schema'; import { HttpResponse } from '@kbn/core/public'; import { AbortError } from '@kbn/kibana-utils-plugin/common'; @@ -16,45 +14,82 @@ import { catchError, concatMap, delay, - filter as rxJsFilter, finalize, - map, of, scan, shareReplay, - tap, + Subject, timestamp, + map, + tap, } from 'rxjs'; import { - ContextRegistry, - FunctionRegistry, + ChatCompletionErrorCode, + ConversationCompletionError, + StreamingChatResponseEvent, + StreamingChatResponseEventType, +} from '../../common/conversation_complete'; +import { FunctionVisibility, - Message, MessageRole, - type RegisterContextDefinition, - type RegisterFunctionDefinition, + type FunctionRegistry, + type FunctionResponse, + type Message, } from '../../common/types'; -import { ObservabilityAIAssistantAPIClient } from '../api'; +import { filterFunctionDefinitions } from '../../common/utils/filter_function_definitions'; +import { processOpenAiStream } from '../../common/utils/process_openai_stream'; +import type { ObservabilityAIAssistantAPIClient } from '../api'; import type { - ChatRegistrationFunction, - CreateChatCompletionResponseChunk, + ChatRegistrationRenderFunction, ObservabilityAIAssistantChatService, PendingMessage, + RenderFunction, } from '../types'; import { readableStreamReaderIntoObservable } from '../utils/readable_stream_reader_into_observable'; -class TokenLimitReachedError extends Error { - constructor() { - super(`Token limit reached`); +const MIN_DELAY = 35; + +function toObservable(response: HttpResponse) { + const status = response.response?.status; + + if (!status || status >= 400) { + throw new Error(response.response?.statusText || 'Unexpected error'); } -} -class ServerError extends 
Error {} + const reader = response.response.body?.getReader(); -export class FunctionArgsValidationError extends Error { - constructor(public readonly errors: OutputUnit[]) { - super('Function arguments are invalid'); + if (!reader) { + throw new Error('Could not get reader from response'); } + + return readableStreamReaderIntoObservable(reader).pipe( + // append a timestamp of when each value was emitted + timestamp(), + // use the previous timestamp to calculate a target + // timestamp for emitting the next value + scan((acc, value) => { + const lastTimestamp = acc.timestamp || 0; + const emitAt = Math.max(lastTimestamp + MIN_DELAY, value.timestamp); + return { + timestamp: emitAt, + value: value.value, + }; + }), + // add the delay based on the elapsed time + // using concatMap(of(value).pipe(delay(50)) + // leads to browser issues because timers + // are throttled when the tab is not active + concatMap((value) => { + const now = Date.now(); + const delayFor = value.timestamp - now; + + if (delayFor <= 0) { + return of(value.value); + } + + return of(value.value).pipe(delay(delayFor)); + }) + ); } export async function createChatService({ @@ -63,72 +98,40 @@ export async function createChatService({ client, }: { signal: AbortSignal; - registrations: ChatRegistrationFunction[]; + registrations: ChatRegistrationRenderFunction[]; client: ObservabilityAIAssistantAPIClient; }): Promise { - const contextRegistry: ContextRegistry = new Map(); const functionRegistry: FunctionRegistry = new Map(); - const validators = new Map(); - - const registerContext: RegisterContextDefinition = (context) => { - contextRegistry.set(context.name, context); - }; - - const registerFunction: RegisterFunctionDefinition = (def, respond, render) => { - validators.set(def.name, new Validator(def.parameters as Schema, '2020-12', true)); - functionRegistry.set(def.name, { options: def, respond, render }); - }; - - const getContexts: ObservabilityAIAssistantChatService['getContexts'] = () => { 
- return Array.from(contextRegistry.values()); - }; - const getFunctions: ObservabilityAIAssistantChatService['getFunctions'] = ({ - contexts, - filter, - } = {}) => { - const allFunctions = Array.from(functionRegistry.values()); + const renderFunctionRegistry: Map> = new Map(); - return contexts || filter - ? allFunctions.filter((fn) => { - const matchesContext = - !contexts || fn.options.contexts.some((context) => contexts.includes(context)); - const matchesFilter = - !filter || fn.options.name.includes(filter) || fn.options.description.includes(filter); - - return matchesContext && matchesFilter; - }) - : allFunctions; + const [{ functionDefinitions, contextDefinitions }] = await Promise.all([ + client('GET /internal/observability_ai_assistant/functions', { + signal: setupAbortSignal, + }), + ...registrations.map((registration) => { + return registration({ + registerRenderFunction: (name, renderFn) => { + renderFunctionRegistry.set(name, renderFn); + }, + }); + }), + ]); + + functionDefinitions.forEach((fn) => { + functionRegistry.set(fn.name, fn); + }); + + const getFunctions = (options?: { contexts?: string[]; filter?: string }) => { + return filterFunctionDefinitions({ + ...options, + definitions: functionDefinitions, + }); }; - await Promise.all( - registrations.map((fn) => fn({ signal: setupAbortSignal, registerContext, registerFunction })) - ); - - function validate(name: string, parameters: unknown) { - const validator = validators.get(name)!; - const result = validator.validate(parameters); - if (!result.valid) { - throw new FunctionArgsValidationError(result.errors); - } - } - return { - executeFunction: async ({ name, args, signal, messages, connectorId }) => { - const fn = functionRegistry.get(name); - - if (!fn) { - throw new Error(`Function ${name} not found`); - } - - const parsedArguments = args ? 
JSON.parse(args) : {}; - - validate(name, parsedArguments); - - return await fn.respond({ arguments: parsedArguments, messages, connectorId }, signal); - }, renderFunction: (name, args, response) => { - const fn = functionRegistry.get(name); + const fn = renderFunctionRegistry.get(name); if (!fn) { throw new Error(`Function ${name} not found`); @@ -141,15 +144,57 @@ export async function createChatService({ data: JSON.parse(response.data ?? '{}'), }; - return fn.render?.({ response: parsedResponse, arguments: parsedArguments }); + return fn?.({ response: parsedResponse, arguments: parsedArguments }); }, - getContexts, + getContexts: () => contextDefinitions, getFunctions, hasFunction: (name: string) => { - return !!getFunctions().find((fn) => fn.options.name === name); + return functionRegistry.has(name); }, hasRenderFunction: (name: string) => { - return !!getFunctions().find((fn) => fn.options.name === name)?.render; + return renderFunctionRegistry.has(name); + }, + complete({ connectorId, messages, conversationId, persist, signal }) { + const subject = new Subject(); + + client('POST /internal/observability_ai_assistant/chat/complete', { + params: { + body: { + messages, + connectorId, + conversationId, + persist, + }, + }, + signal, + asResponse: true, + rawResponse: true, + }) + .then((_response) => { + const response = _response as unknown as HttpResponse; + const response$ = toObservable(response) + .pipe( + map((line) => JSON.parse(line) as StreamingChatResponseEvent), + tap((event) => { + if (event.type === StreamingChatResponseEventType.ConversationCompletionError) { + const code = event.error.code ?? 
ChatCompletionErrorCode.InternalError; + const message = event.error.message; + throw new ConversationCompletionError(code, message); + } + }) + ) + .subscribe(subject); + + signal.addEventListener('abort', () => { + response$.unsubscribe(); + }); + }) + .catch((err) => { + subject.error(err); + subject.complete(); + }); + + return subject; }, chat({ connectorId, @@ -181,8 +226,8 @@ export async function createChatService({ callFunctions === 'none' ? [] : functions - .filter((fn) => fn.options.visibility !== FunctionVisibility.User) - .map((fn) => pick(fn.options, 'name', 'description', 'parameters')), + .filter((fn) => fn.visibility !== FunctionVisibility.User) + .map((fn) => pick(fn, 'name', 'description', 'parameters')), }, }, signal: controller.signal, @@ -192,51 +237,9 @@ export async function createChatService({ .then((_response) => { const response = _response as unknown as HttpResponse; - const status = response.response?.status; - - if (!status || status >= 400) { - throw new Error(response.response?.statusText || 'Unexpected error'); - } - - const reader = response.response.body?.getReader(); - - if (!reader) { - throw new Error('Could not get reader from response'); - } - - const subscription = readableStreamReaderIntoObservable(reader) + const subscription = toObservable(response) .pipe( - // lines start with 'data: ' - map((line) => line.substring(6)), - // a message completes with the line '[DONE]' - rxJsFilter((line) => !!line && line !== '[DONE]'), - // parse the JSON, add the type - map( - (line) => - JSON.parse(line) as - | CreateChatCompletionResponseChunk - | { error: { message: string } } - ), - // validate the message. 
in some cases OpenAI - // will throw halfway through the message - tap((line) => { - if ('error' in line) { - throw new ServerError(line.error.message); - } - }), - // there also might be some metadata that we need - // to exclude - rxJsFilter( - (line): line is CreateChatCompletionResponseChunk => - 'object' in line && line.object === 'chat.completion.chunk' - ), - // this is how OpenAI signals that the context window - // limit has been exceeded - tap((line) => { - if (line.choices[0].finish_reason === 'length') { - throw new TokenLimitReachedError(); - } - }), + processOpenAiStream(), // merge the messages scan( (acc, { choices }) => { @@ -298,8 +301,6 @@ export async function createChatService({ subject.complete(); }); - const MIN_DELAY = 35; - const pendingMessages$ = subject.pipe( // make sure the request is only triggered once, // even with multiple subscribers @@ -308,32 +309,6 @@ export async function createChatService({ // abort the running request finalize(() => { controller.abort(); - }), - // append a timestamp of when each value was emitted - timestamp(), - // use the previous timestamp to calculate a target - // timestamp for emitting the next value - scan((acc, value) => { - const lastTimestamp = acc.timestamp || 0; - const emitAt = Math.max(lastTimestamp + MIN_DELAY, value.timestamp); - return { - timestamp: emitAt, - value: value.value, - }; - }), - // add the delay based on the elapsed time - // using concatMap(of(value).pipe(delay(50)) - // leads to browser issues because timers - // are throttled when the tab is not active - concatMap((value) => { - const now = Date.now(); - const delayFor = value.timestamp - now; - - if (delayFor <= 0) { - return of(value.value); - } - - return of(value.value).pipe(delay(delayFor)); }) ); diff --git a/x-pack/plugins/observability_ai_assistant/public/service/create_mock_chat_service.ts b/x-pack/plugins/observability_ai_assistant/public/service/create_mock_chat_service.ts index e255aa830467e4..270a4e62f5fc2c 
100644 --- a/x-pack/plugins/observability_ai_assistant/public/service/create_mock_chat_service.ts +++ b/x-pack/plugins/observability_ai_assistant/public/service/create_mock_chat_service.ts @@ -13,7 +13,7 @@ type MockedChatService = DeeplyMockedKeys; export const createMockChatService = (): MockedChatService => { const mockChatService: MockedChatService = { chat: jest.fn(), - executeFunction: jest.fn(), + complete: jest.fn(), getContexts: jest.fn().mockReturnValue([{ name: 'core', description: '' }]), getFunctions: jest.fn().mockReturnValue([]), hasFunction: jest.fn().mockReturnValue(false), diff --git a/x-pack/plugins/observability_ai_assistant/public/service/create_service.ts b/x-pack/plugins/observability_ai_assistant/public/service/create_service.ts index 5e7356250d5156..f686ee3fbfd866 100644 --- a/x-pack/plugins/observability_ai_assistant/public/service/create_service.ts +++ b/x-pack/plugins/observability_ai_assistant/public/service/create_service.ts @@ -10,7 +10,7 @@ import type { LicensingPluginStart } from '@kbn/licensing-plugin/public'; import type { SecurityPluginStart } from '@kbn/security-plugin/public'; import type { SharePluginStart } from '@kbn/share-plugin/public'; import { createCallObservabilityAIAssistantAPI } from '../api'; -import type { ChatRegistrationFunction, ObservabilityAIAssistantService } from '../types'; +import type { ChatRegistrationRenderFunction, ObservabilityAIAssistantService } from '../types'; export function createService({ coreStart, @@ -24,10 +24,10 @@ export function createService({ licenseStart: LicensingPluginStart; securityStart: SecurityPluginStart; shareStart: SharePluginStart; -}): ObservabilityAIAssistantService & { register: (fn: ChatRegistrationFunction) => void } { +}): ObservabilityAIAssistantService & { register: (fn: ChatRegistrationRenderFunction) => void } { const client = createCallObservabilityAIAssistantAPI(coreStart); - const registrations: ChatRegistrationFunction[] = []; + const registrations: 
ChatRegistrationRenderFunction[] = []; return { isEnabled: () => { @@ -40,7 +40,6 @@ export function createService({ const mod = await import('./create_chat_service'); return await mod.createChatService({ client, signal, registrations }); }, - callApi: client, getCurrentUser: () => securityStart.authc.getCurrentUser(), getLicense: () => licenseStart.license$, diff --git a/x-pack/plugins/observability_ai_assistant/public/types.ts b/x-pack/plugins/observability_ai_assistant/public/types.ts index 99853b7b313b7c..bf8d0d6870b337 100644 --- a/x-pack/plugins/observability_ai_assistant/public/types.ts +++ b/x-pack/plugins/observability_ai_assistant/public/types.ts @@ -18,11 +18,6 @@ import type { TriggersAndActionsUIPublicPluginSetup, TriggersAndActionsUIPublicPluginStart, } from '@kbn/triggers-actions-ui-plugin/public'; -import type { Serializable } from '@kbn/utility-types'; -import type { - CreateChatCompletionResponse, - CreateChatCompletionResponseChoicesInner, -} from 'openai'; import type { Observable } from 'rxjs'; import type { LensPublicSetup, LensPublicStart } from '@kbn/lens-plugin/public'; import type { @@ -34,22 +29,16 @@ import type { SharePluginStart } from '@kbn/share-plugin/public'; import type { ContextDefinition, FunctionDefinition, + FunctionResponse, Message, - RegisterContextDefinition, - RegisterFunctionDefinition, } from '../common/types'; import type { ObservabilityAIAssistantAPIClient } from './api'; import type { PendingMessage } from '../common/types'; +import type { StreamingChatResponseEvent } from '../common/conversation_complete'; /* eslint-disable @typescript-eslint/no-empty-interface*/ -export type CreateChatCompletionResponseChunk = Omit & { - choices: Array< - Omit & { - delta: { content?: string; function_call?: { name?: string; arguments?: string } }; - } - >; -}; +export type { CreateChatCompletionResponseChunk } from '../common/types'; export interface ObservabilityAIAssistantChatService { chat: (options: { @@ -57,17 +46,17 @@ 
export interface ObservabilityAIAssistantChatService { connectorId: string; function?: 'none' | 'auto'; }) => Observable; + complete: (options: { + messages: Message[]; + connectorId: string; + persist: boolean; + conversationId?: string; + signal: AbortSignal; + }) => Observable; getContexts: () => ContextDefinition[]; getFunctions: (options?: { contexts?: string[]; filter?: string }) => FunctionDefinition[]; hasFunction: (name: string) => boolean; hasRenderFunction: (name: string) => boolean; - executeFunction: ({}: { - name: string; - args: string | undefined; - messages: Message[]; - signal: AbortSignal; - connectorId: string; - }) => Promise<{ content?: Serializable; data?: Serializable } | Observable>; renderFunction: ( name: string, args: string | undefined, @@ -75,12 +64,6 @@ export interface ObservabilityAIAssistantChatService { ) => React.ReactNode; } -export type ChatRegistrationFunction = ({}: { - signal: AbortSignal; - registerFunction: RegisterFunctionDefinition; - registerContext: RegisterContextDefinition; -}) => Promise; - export interface ObservabilityAIAssistantService { isEnabled: () => boolean; callApi: ObservabilityAIAssistantAPIClient; @@ -90,8 +73,22 @@ export interface ObservabilityAIAssistantService { start: ({}: { signal: AbortSignal }) => Promise; } +export type RenderFunction = (options: { + arguments: TArguments; + response: TResponse; +}) => React.ReactNode; + +export type RegisterRenderFunctionDefinition< + TFunctionArguments = any, + TFunctionResponse extends FunctionResponse = FunctionResponse +> = (name: string, render: RenderFunction) => void; + +export type ChatRegistrationRenderFunction = ({}: { + registerRenderFunction: RegisterRenderFunctionDefinition; +}) => Promise; + export interface ObservabilityAIAssistantPluginStart extends ObservabilityAIAssistantService { - register: (fn: ChatRegistrationFunction) => void; + register: (fn: ChatRegistrationRenderFunction) => void; } export interface ObservabilityAIAssistantPluginSetup 
{} diff --git a/x-pack/plugins/observability_ai_assistant/public/utils/builders.ts b/x-pack/plugins/observability_ai_assistant/public/utils/builders.ts index 6f2d1e5c2f0905..f54995d3d17cd4 100644 --- a/x-pack/plugins/observability_ai_assistant/public/utils/builders.ts +++ b/x-pack/plugins/observability_ai_assistant/public/utils/builders.ts @@ -7,7 +7,7 @@ import { merge, uniqueId } from 'lodash'; import { DeepPartial } from 'utility-types'; -import { MessageRole, Conversation, FunctionDefinition, Message } from '../../common/types'; +import { Conversation, FunctionDefinition, Message, MessageRole } from '../../common/types'; import { getAssistantSetupMessage } from '../service/get_assistant_setup_message'; type BuildMessageProps = DeepPartial & { @@ -121,28 +121,25 @@ export function buildConversation(params?: Partial) { export function buildFunction(): FunctionDefinition { return { - options: { - name: 'elasticsearch', - contexts: ['core'], - description: 'Call Elasticsearch APIs on behalf of the user', - descriptionForUser: 'Call Elasticsearch APIs on behalf of the user', - parameters: { - type: 'object', - properties: { - method: { - type: 'string', - description: 'The HTTP method of the Elasticsearch endpoint', - enum: ['GET', 'PUT', 'POST', 'DELETE', 'PATCH'] as const, - }, - path: { - type: 'string', - description: 'The path of the Elasticsearch endpoint, including query parameters', - }, + name: 'elasticsearch', + contexts: ['core'], + description: 'Call Elasticsearch APIs on behalf of the user', + descriptionForUser: 'Call Elasticsearch APIs on behalf of the user', + parameters: { + type: 'object', + properties: { + method: { + type: 'string', + description: 'The HTTP method of the Elasticsearch endpoint', + enum: ['GET', 'PUT', 'POST', 'DELETE', 'PATCH'] as const, + }, + path: { + type: 'string', + description: 'The path of the Elasticsearch endpoint, including query parameters', }, - required: ['method' as const, 'path' as const], }, + required: ['method' 
as const, 'path' as const], }, - respond: async (options: { arguments: any }, signal: AbortSignal) => ({}), }; } @@ -150,16 +147,13 @@ export const buildFunctionElasticsearch = buildFunction; export function buildFunctionServiceSummary(): FunctionDefinition { return { - options: { - name: 'get_service_summary', - contexts: ['core'], - description: - 'Gets a summary of a single service, including: the language, service version, deployments, infrastructure, alerting, etc. ', - descriptionForUser: 'Get a summary for a single service.', - parameters: { - type: 'object', - }, + name: 'get_service_summary', + contexts: ['core'], + description: + 'Gets a summary of a single service, including: the language, service version, deployments, infrastructure, alerting, etc. ', + descriptionForUser: 'Get a summary for a single service.', + parameters: { + type: 'object', }, - respond: async (options: { arguments: any }, signal: AbortSignal) => ({}), }; } diff --git a/x-pack/plugins/observability_ai_assistant/public/utils/create_initialized_object.ts b/x-pack/plugins/observability_ai_assistant/public/utils/create_initialized_object.ts index 42f314766dca90..6ae23042a63e65 100644 --- a/x-pack/plugins/observability_ai_assistant/public/utils/create_initialized_object.ts +++ b/x-pack/plugins/observability_ai_assistant/public/utils/create_initialized_object.ts @@ -7,7 +7,7 @@ import { FunctionDefinition } from '../../common/types'; -type Params = FunctionDefinition['options']['parameters']; +type Params = FunctionDefinition['parameters']; export function createInitializedObject(parameters: Params) { const emptyObject: Record = {}; diff --git a/x-pack/plugins/observability_ai_assistant/public/utils/storybook_decorator.tsx b/x-pack/plugins/observability_ai_assistant/public/utils/storybook_decorator.tsx index 68dd784a1f8aa5..1c51253da1222c 100644 --- a/x-pack/plugins/observability_ai_assistant/public/utils/storybook_decorator.tsx +++ 
b/x-pack/plugins/observability_ai_assistant/public/utils/storybook_decorator.tsx @@ -4,35 +4,35 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ -import React, { ComponentType } from 'react'; -import { Observable } from 'rxjs'; +import { i18n } from '@kbn/i18n'; import { KibanaContextProvider } from '@kbn/kibana-react-plugin/public'; -import type { Serializable } from '@kbn/utility-types'; import type { AuthenticatedUser } from '@kbn/security-plugin/common'; import type { SharePluginStart } from '@kbn/share-plugin/public'; -import { ObservabilityAIAssistantProvider } from '../context/observability_ai_assistant_provider'; +import React, { ComponentType } from 'react'; +import { Observable } from 'rxjs'; +import { StreamingChatResponseEvent } from '../../common/conversation_complete'; import { ObservabilityAIAssistantAPIClient } from '../api'; -import type { Message } from '../../common'; +import { ObservabilityAIAssistantChatServiceProvider } from '../context/observability_ai_assistant_chat_service_provider'; +import { ObservabilityAIAssistantProvider } from '../context/observability_ai_assistant_provider'; import type { ObservabilityAIAssistantChatService, ObservabilityAIAssistantService, PendingMessage, } from '../types'; import { buildFunctionElasticsearch, buildFunctionServiceSummary } from './builders'; -import { ObservabilityAIAssistantChatServiceProvider } from '../context/observability_ai_assistant_chat_service_provider'; const chatService: ObservabilityAIAssistantChatService = { - chat: (options: { messages: Message[]; connectorId: string }) => new Observable(), + chat: (options) => new Observable(), + complete: (options) => new Observable(), getContexts: () => [], getFunctions: () => [buildFunctionElasticsearch(), buildFunctionServiceSummary()], - executeFunction: async ({}: { - name: string; - args: string | undefined; - messages: Message[]; - signal: AbortSignal; - }): Promise<{ content?: Serializable; data?: 
Serializable }> => ({}), - renderFunction: (name: string, args: string | undefined, response: {}) => ( -
Hello! {name}
+ renderFunction: (name) => ( +
+ {i18n.translate('xpack.observabilityAiAssistant.chatService.div.helloLabel', { + defaultMessage: 'Hello', + })} + {name} +
), hasFunction: () => true, hasRenderFunction: () => true, diff --git a/x-pack/plugins/observability_ai_assistant/server/functions/alerts.ts b/x-pack/plugins/observability_ai_assistant/server/functions/alerts.ts new file mode 100644 index 00000000000000..e58d75c52cf742 --- /dev/null +++ b/x-pack/plugins/observability_ai_assistant/server/functions/alerts.ts @@ -0,0 +1,153 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import datemath from '@elastic/datemath'; +import { fromKueryExpression, toElasticsearchQuery } from '@kbn/es-query'; +import { ParsedTechnicalFields } from '@kbn/rule-registry-plugin/common'; +import { + ALERT_STATUS, + ALERT_STATUS_ACTIVE, +} from '@kbn/rule-registry-plugin/common/technical_rule_data_field_names'; +import { omit } from 'lodash'; +import { FunctionRegistrationParameters } from '.'; + +const OMITTED_ALERT_FIELDS = [ + 'tags', + 'event.action', + 'event.kind', + 'kibana.alert.rule.execution.uuid', + 'kibana.alert.rule.revision', + 'kibana.alert.rule.tags', + 'kibana.alert.rule.uuid', + 'kibana.alert.workflow_status', + 'kibana.space_ids', + 'kibana.alert.time_range', + 'kibana.version', +] as const; + +const DEFAULT_FEATURE_IDS = [ + 'apm', + 'infrastructure', + 'logs', + 'uptime', + 'slo', + 'observability', +] as const; + +export function registerAlertsFunction({ + client, + registerFunction, + resources, +}: FunctionRegistrationParameters) { + registerFunction( + { + name: 'alerts', + contexts: ['core'], + description: + 'Get alerts for Observability. 
Display the response in tabular format if appropriate.', + descriptionForUser: 'Get alerts for Observability', + parameters: { + type: 'object', + additionalProperties: false, + properties: { + featureIds: { + type: 'array', + additionalItems: false, + items: { + type: 'string', + enum: DEFAULT_FEATURE_IDS, + }, + description: + 'The Observability apps for which to retrieve alerts. By default it will return alerts for all apps.', + }, + start: { + type: 'string', + description: 'The start of the time range, in Elasticsearch date math, like `now`.', + }, + end: { + type: 'string', + description: 'The end of the time range, in Elasticsearch date math, like `now-24h`.', + }, + filter: { + type: 'string', + description: + 'a KQL query to filter the data by. If no filter should be applied, leave it empty.', + }, + includeRecovered: { + type: 'boolean', + description: + 'Whether to include recovered/closed alerts. Defaults to false, which means only active alerts will be returned', + }, + }, + required: ['start', 'end'], + } as const, + }, + async ( + { + arguments: { + start: startAsDatemath, + end: endAsDatemath, + featureIds, + filter, + includeRecovered, + }, + }, + signal + ) => { + const racContext = await resources.context.rac; + const alertsClient = await racContext.getAlertsClient(); + + const start = datemath.parse(startAsDatemath)!.valueOf(); + const end = datemath.parse(endAsDatemath)!.valueOf(); + + const kqlQuery = !filter ? [] : [toElasticsearchQuery(fromKueryExpression(filter))]; + + const response = await alertsClient.find({ + featureIds: + !!featureIds && !!featureIds.length + ? featureIds + : (DEFAULT_FEATURE_IDS as unknown as string[]), + query: { + bool: { + filter: [ + { + range: { + '@timestamp': { + gte: start, + lte: end, + }, + }, + }, + ...kqlQuery, + ...(!includeRecovered + ? 
[ + { + term: { + [ALERT_STATUS]: ALERT_STATUS_ACTIVE, + }, + }, + ] + : []), + ], + }, + }, + }); + + // trim some fields + const alerts = response.hits.hits.map((hit) => + omit(hit._source, ...OMITTED_ALERT_FIELDS) + ) as unknown as ParsedTechnicalFields[]; + + return { + content: { + total: (response.hits as { total: { value: number } }).total.value, + alerts, + }, + }; + } + ); +} diff --git a/x-pack/plugins/observability_ai_assistant/public/functions/elasticsearch.ts b/x-pack/plugins/observability_ai_assistant/server/functions/elasticsearch.ts similarity index 65% rename from x-pack/plugins/observability_ai_assistant/public/functions/elasticsearch.ts rename to x-pack/plugins/observability_ai_assistant/server/functions/elasticsearch.ts index 546bd2bea45749..44cb30233504ab 100644 --- a/x-pack/plugins/observability_ai_assistant/public/functions/elasticsearch.ts +++ b/x-pack/plugins/observability_ai_assistant/server/functions/elasticsearch.ts @@ -5,17 +5,12 @@ * 2.0. */ -import type { Serializable } from '@kbn/utility-types'; -import type { RegisterFunctionDefinition } from '../../common/types'; -import type { ObservabilityAIAssistantService } from '../types'; +import type { FunctionRegistrationParameters } from '.'; export function registerElasticsearchFunction({ - service, registerFunction, -}: { - service: ObservabilityAIAssistantService; - registerFunction: RegisterFunctionDefinition; -}) { + resources, +}: FunctionRegistrationParameters) { registerFunction( { name: 'elasticsearch', @@ -43,19 +38,16 @@ export function registerElasticsearchFunction({ required: ['method', 'path'] as const, }, }, - ({ arguments: { method, path, body } }, signal) => { - return service - .callApi(`POST /internal/observability_ai_assistant/functions/elasticsearch`, { - signal, - params: { - body: { - method, - path, - body, - }, - }, - }) - .then((response) => ({ content: response as Serializable })); + async ({ arguments: { method, path, body } }) => { + const response = await ( + 
await resources.context.core + ).elasticsearch.client.asCurrentUser.transport.request({ + method, + path, + body, + }); + + return { content: response }; } ); } diff --git a/x-pack/plugins/observability_ai_assistant/public/functions/esql.ts b/x-pack/plugins/observability_ai_assistant/server/functions/esql.ts similarity index 89% rename from x-pack/plugins/observability_ai_assistant/public/functions/esql.ts rename to x-pack/plugins/observability_ai_assistant/server/functions/esql.ts index 56c3c833608214..88997452c0ad8e 100644 --- a/x-pack/plugins/observability_ai_assistant/public/functions/esql.ts +++ b/x-pack/plugins/observability_ai_assistant/server/functions/esql.ts @@ -6,22 +6,21 @@ */ import dedent from 'dedent'; -import type { Serializable } from '@kbn/utility-types'; -import { concat, last, map } from 'rxjs'; +import { Observable } from 'rxjs'; +import type { FunctionRegistrationParameters } from '.'; import { + type CreateChatCompletionResponseChunk, FunctionVisibility, MessageRole, - type RegisterFunctionDefinition, } from '../../common/types'; -import type { ObservabilityAIAssistantService } from '../types'; +import { processOpenAiStream } from '../../common/utils/process_openai_stream'; +import { streamIntoObservable } from '../service/util/stream_into_observable'; export function registerEsqlFunction({ - service, + client, registerFunction, -}: { - service: ObservabilityAIAssistantService; - registerFunction: RegisterFunctionDefinition; -}) { + resources, +}: FunctionRegistrationParameters) { registerFunction( { name: 'execute_query', @@ -39,21 +38,18 @@ export function registerEsqlFunction({ required: ['query'], } as const, }, - ({ arguments: { query } }, signal) => { - return service - .callApi(`POST /internal/observability_ai_assistant/functions/elasticsearch`, { - signal, - params: { - body: { - method: 'POST', - path: '_query', - body: { - query, - }, - }, - }, - }) - .then((response) => ({ content: response as Serializable })); + async ({ 
arguments: { query } }) => { + const response = await ( + await resources.context.core + ).elasticsearch.client.asCurrentUser.transport.request({ + method: 'POST', + path: '_query', + body: { + query, + }, + }); + + return { content: response }; } ); @@ -73,10 +69,10 @@ export function registerEsqlFunction({ }, } as const, }, - ({ messages, connectorId }, signal) => { + async ({ messages, connectorId }, signal) => { const systemMessage = dedent(`You are a helpful assistant for Elastic ES|QL. Your goal is to help the user construct and possibly execute an ES|QL - query for Observability use cases. + query for Observability use cases. ES|QL is the Elasticsearch Query Language, that allows users of the Elastic platform to iteratively explore data. An ES|QL query consists @@ -92,7 +88,7 @@ export function registerEsqlFunction({ the context of this conversation. # Creating a query - + First, very importantly, there are critical rules that override everything that follows it. Always repeat these rules, verbatim. @@ -144,6 +140,10 @@ export function registerEsqlFunction({ -- Let's break down the query step-by-step: + + \`\`\`esql + + \`\`\` \`\`\` Always format a complete query as follows: @@ -203,7 +203,7 @@ export function registerEsqlFunction({ - \`1 year\` - \`2 milliseconds\` - ## Aliasing + ## Aliasing Aliasing happens through the \`=\` operator. Example: \`STATS total_salary_expenses = COUNT(salary)\` @@ -211,7 +211,7 @@ export function registerEsqlFunction({ # Source commands - There are three source commands: FROM (which selects an index), ROW + There are three source commands: FROM (which selects an index), ROW (which creates data from the command) and SHOW (which returns information about the deployment). You do not support SHOW for now. @@ -276,10 +276,10 @@ export function registerEsqlFunction({ This is right: \`| STATS avg_cpu = AVG(cpu) | SORT avg_cpu\` ### EVAL - + \`EVAL\` appends a new column to the documents by using aliasing. 
It also supports functions, but not aggregation functions like COUNT: - + - \`\`\` | EVAL monthly_salary = yearly_salary / 12, total_comp = ROUND(yearly_salary + yearly+bonus), @@ -396,7 +396,7 @@ export function registerEsqlFunction({ can be expressed using the timespan literal syntax. Use this together with STATS ... BY to group data into time buckets with a fixed interval. Some examples: - + - \`| EVAL year_hired = DATE_TRUNC(1 year, hire_date)\` - \`| EVAL month_logged = DATE_TRUNC(1 month, @timestamp)\` - \`| EVAL bucket = DATE_TRUNC(1 minute, @timestamp) | STATS avg_salary = AVG(salary) BY bucket\` @@ -431,7 +431,7 @@ export function registerEsqlFunction({ Returns the greatest or least of two or numbers. Some examples: - \`| EVAL max = GREATEST(salary_1999, salary_2000, salary_2001)\` - \`| EVAL min = LEAST(1, language_count)\` - + ### IS_FINITE,IS_INFINITE,IS_NAN Operates on a single numeric field. Some examples: @@ -459,7 +459,7 @@ export function registerEsqlFunction({ - \`| EVAL version = TO_VERSION("1.2.3")\` - \`| EVAL as_bool = TO_BOOLEAN(my_boolean_string)\` - \`| EVAL percent = TO_DOUBLE(part) / TO_DOUBLE(total)\` - + ### TRIM Trims leading and trailing whitespace. Some examples: @@ -482,7 +482,7 @@ export function registerEsqlFunction({ argument, and does not support wildcards. One single argument is required. If you don't have a field name, use whatever field you have, rather than displaying an invalid query. - + Some examples: - \`| STATS doc_count = COUNT(emp_no)\` @@ -496,16 +496,16 @@ export function registerEsqlFunction({ - \`| STATS first_name = COUNT_DISTINCT(first_name)\` ### PERCENTILE - + \`PERCENTILE\` returns the percentile value for a specific field. 
Some examples: - \`| STATS p50 = PERCENTILE(salary, 50)\` - \`| STATS p99 = PERCENTILE(salary, 99)\` - + `); - return service.start({ signal }).then((client) => { - const source$ = client.chat({ + const source$ = streamIntoObservable( + await client.chat({ connectorId, messages: [ { @@ -514,46 +514,48 @@ export function registerEsqlFunction({ }, ...messages.slice(1), ], - }); + signal, + stream: true, + }) + ).pipe(processOpenAiStream()); - const pending$ = source$.pipe( - map((message) => { - const content = message.message.content || ''; - let next: string = ''; - - if (content.length <= 2) { - next = ''; - } else if (content.includes('--')) { - next = message.message.content?.split('--')[2] || ''; - } else { - next = content; - } + return new Observable((subscriber) => { + let cachedContent: string = ''; - return { - ...message, - message: { - ...message.message, - content: next, - }, - }; - }) - ); - const onComplete$ = source$.pipe( - last(), - map((message) => { - const [, , next] = message.message.content?.split('--') ?? 
[]; - - return { - ...message, - message: { - ...message.message, - content: next || message.message.content, - }, - }; - }) - ); - - return concat(pending$, onComplete$); + function includesDivider() { + const firstDividerIndex = cachedContent.indexOf('--'); + return firstDividerIndex !== -1 && cachedContent.lastIndexOf('--') !== firstDividerIndex; + } + + source$.subscribe({ + next: (message) => { + if (includesDivider()) { + subscriber.next(message); + } + cachedContent += message.choices[0].delta.content || ''; + }, + complete: () => { + if (!includesDivider()) { + subscriber.next({ + created: 0, + id: '', + model: '', + object: 'chat.completion.chunk', + choices: [ + { + delta: { + content: cachedContent, + }, + }, + ], + }); + } + subscriber.complete(); + }, + error: (error) => { + subscriber.error(error); + }, + }); }); } ); diff --git a/x-pack/plugins/observability_ai_assistant/server/functions/get_dataset_info.ts b/x-pack/plugins/observability_ai_assistant/server/functions/get_dataset_info.ts new file mode 100644 index 00000000000000..df9cd0cd231d62 --- /dev/null +++ b/x-pack/plugins/observability_ai_assistant/server/functions/get_dataset_info.ts @@ -0,0 +1,191 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +import { chunk, groupBy, uniq } from 'lodash'; +import { lastValueFrom } from 'rxjs'; +import { FunctionRegistrationParameters } from '.'; +import { FunctionVisibility, MessageRole } from '../../common/types'; +import { concatenateOpenAiChunks } from '../../common/utils/concatenate_openai_chunks'; +import { processOpenAiStream } from '../../common/utils/process_openai_stream'; +import { streamIntoObservable } from '../service/util/stream_into_observable'; + +export function registerGetDatasetInfoFunction({ + client, + resources, + registerFunction, +}: FunctionRegistrationParameters) { + registerFunction( + { + name: 'get_dataset_info', + contexts: ['core'], + visibility: FunctionVisibility.System, + description: `Use this function to get information about indices/datasets available and the fields available on them. + + providing empty string as index name will retrieve all indices + else list of all fields for the given index will be given. if no fields are returned this means no indices were matched by provided index pattern. + wildcards can be part of index name.`, + descriptionForUser: + 'This function allows the assistant to get information about available indices and their fields.', + parameters: { + type: 'object', + additionalProperties: false, + properties: { + index: { + type: 'string', + description: + 'index pattern the user is interested in or empty string to get information about all available indices', + }, + }, + required: ['index'], + } as const, + }, + async ({ arguments: { index }, messages, connectorId }, signal) => { + const coreContext = await resources.context.core; + + const esClient = coreContext.elasticsearch.client.asCurrentUser; + const savedObjectsClient = coreContext.savedObjects.getClient(); + + let indices: string[] = []; + + try { + const body = await esClient.indices.resolveIndex({ + name: index === '' ? 
'*' : index, + expand_wildcards: 'open', + }); + indices = [...body.indices.map((i) => i.name), ...body.data_streams.map((d) => d.name)]; + } catch (e) { + indices = []; + } + + if (index === '') { + return { + indices, + fields: [], + }; + } + + if (indices.length === 0) { + return { + indices, + fields: [], + }; + } + + const fields = await resources.plugins.dataViews + .start() + .then((dataViewsStart) => + dataViewsStart.dataViewsServiceFactory(savedObjectsClient, esClient) + ) + .then((service) => + service.getFieldsForWildcard({ + pattern: index, + }) + ); + + // else get all the fields for the found dataview + const response = { + indices: [index], + fields: fields.flatMap((field) => { + return (field.esTypes ?? [field.type]).map((type) => { + return { + name: field.name, + description: field.customLabel || '', + type, + }; + }); + }), + }; + + const allFields = response.fields; + + const fieldNames = uniq(allFields.map((field) => field.name)); + + const groupedFields = groupBy(allFields, (field) => field.name); + + const relevantFields = await Promise.all( + chunk(fieldNames, 500).map(async (fieldsInChunk) => { + const chunkResponse$ = streamIntoObservable( + await client.chat({ + connectorId, + signal, + messages: [ + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.System, + content: `You are a helpful assistant for Elastic Observability. + Your task is to create a list of field names that are relevant + to the conversation, using ONLY the list of fields and + types provided in the last user message. 
DO NOT UNDER ANY + CIRCUMSTANCES include fields not mentioned in this list.`, + }, + }, + ...messages.slice(1), + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.User, + content: `This is the list: + + ${fieldsInChunk.join('\n')}`, + }, + }, + ], + functions: [ + { + name: 'fields', + description: 'The fields you consider relevant to the conversation', + parameters: { + type: 'object', + additionalProperties: false, + properties: { + fields: { + type: 'array', + additionalProperties: false, + items: { + type: 'string', + additionalProperties: false, + }, + }, + }, + required: ['fields'], + } as const, + }, + ], + functionCall: 'fields', + stream: true, + }) + ).pipe(processOpenAiStream(), concatenateOpenAiChunks()); + + const chunkResponse = await lastValueFrom(chunkResponse$); + + return chunkResponse.message?.function_call?.arguments + ? ( + JSON.parse(chunkResponse.message.function_call.arguments) as { + fields: string[]; + } + ).fields + .filter((field) => fieldNames.includes(field)) + .map((field) => { + const fieldDescriptors = groupedFields[field]; + return `${field}:${fieldDescriptors + .map((descriptor) => descriptor.type) + .join(',')}`; + }) + : [chunkResponse.message?.content ?? '']; + }) + ); + + return { + content: { + indices: response.indices, + fields: relevantFields.flat(), + }, + }; + } + ); +} diff --git a/x-pack/plugins/observability_ai_assistant/server/functions/index.ts b/x-pack/plugins/observability_ai_assistant/server/functions/index.ts new file mode 100644 index 00000000000000..db673a725f2b85 --- /dev/null +++ b/x-pack/plugins/observability_ai_assistant/server/functions/index.ts @@ -0,0 +1,86 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +import dedent from 'dedent'; +import { registerRecallFunction } from './recall'; +import { registerSummarizationFunction } from './summarize'; +import { ChatRegistrationFunction } from '../service/types'; +import { registerAlertsFunction } from './alerts'; +import { registerElasticsearchFunction } from './elasticsearch'; +import { registerEsqlFunction } from './esql'; +import { registerGetDatasetInfoFunction } from './get_dataset_info'; +import { registerLensFunction } from './lens'; + +export type FunctionRegistrationParameters = Omit< + Parameters[0], + 'registerContext' +>; + +export const registerFunctions: ChatRegistrationFunction = async ({ + client, + registerContext, + registerFunction, + resources, + signal, +}) => { + const registrationParameters: FunctionRegistrationParameters = { + client, + registerFunction, + resources, + signal, + }; + return client.getKnowledgeBaseStatus().then((response) => { + const isReady = response.ready; + + let description = dedent( + `You are a helpful assistant for Elastic Observability. Your goal is to help the Elastic Observability users to quickly assess what is happening in their observed systems. You can help them visualise and analyze data, investigate their systems, perform root cause analysis or identify optimisation opportunities. + + It's very important to not assume what the user is meaning. Ask them for clarification if needed. + + If you are unsure about which function should be used and with what arguments, ask the user for clarification or confirmation. + + In KQL, escaping happens with double quotes, not single quotes. Some characters that need escaping are: ':()\\\ + /\". Always put a field value in double quotes. Best: service.name:\"opbeans-go\". Wrong: service.name:opbeans-go. This is very important! + + You can use Github-flavored Markdown in your responses. If a function returns an array, consider using a Markdown table to format the response. 
+ + If multiple functions are suitable, use the most specific and easy one. E.g., when the user asks to visualise APM data, use the APM functions (if available) rather than Lens. + + If a function call fails, DO NOT UNDER ANY CIRCUMSTANCES execute it again. Ask the user for guidance and offer them options. + + Note that ES|QL (the Elasticsearch query language, which is NOT Elasticsearch SQL, but a new piped language) is the preferred query language. + + If the user asks about a query, or ES|QL, always call the "esql" function. DO NOT UNDER ANY CIRCUMSTANCES generate ES|QL queries yourself. Even if the "recall" function was used before that, follow it up with the "esql" function.` + ); + + if (isReady) { + description += `You can use the "summarize" functions to store new information you have learned in a knowledge database. Once you have established that you did not know the answer to a question, and the user gave you this information, it's important that you create a summarisation of what you have learned and store it in the knowledge database. Don't create a new summarization if you see a similar summarization in the conversation, instead, update the existing one by re-using its ID. + + Additionally, you can use the "recall" function to retrieve relevant information from the knowledge database. + `; + + description += `Here are principles you MUST adhere to, in order: + - DO NOT make any assumptions about where and how users have stored their data. ALWAYS first call get_dataset_info function with empty string to get information about available indices. Once you know about available indices you MUST use this function again to get a list of available fields for specific index. If user provides an index name make sure its a valid index first before using it to retrieve the field list by calling this function with an empty string! 
+ `; + registerSummarizationFunction(registrationParameters); + registerRecallFunction(registrationParameters); + registerLensFunction(registrationParameters); + } else { + description += `You do not have a working memory. Don't try to recall information via the "recall" function. If the user expects you to remember the previous conversations, tell them they can set up the knowledge base. A banner is available at the top of the conversation to set this up.`; + } + + registerElasticsearchFunction(registrationParameters); + registerEsqlFunction(registrationParameters); + registerAlertsFunction(registrationParameters); + registerGetDatasetInfoFunction(registrationParameters); + + registerContext({ + name: 'core', + description: dedent(description), + }); + }); +}; diff --git a/x-pack/plugins/observability_ai_assistant/server/functions/lens.ts b/x-pack/plugins/observability_ai_assistant/server/functions/lens.ts new file mode 100644 index 00000000000000..62e0f98c1b65df --- /dev/null +++ b/x-pack/plugins/observability_ai_assistant/server/functions/lens.ts @@ -0,0 +1,16 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ +import { lensFunctionDefinition } from '../../common/functions/lens'; +import { RegisterFunction } from '../service/types'; + +export function registerLensFunction({ registerFunction }: { registerFunction: RegisterFunction }) { + registerFunction(lensFunctionDefinition, async () => { + return { + content: {}, + }; + }); +} diff --git a/x-pack/plugins/observability_ai_assistant/public/functions/recall.ts b/x-pack/plugins/observability_ai_assistant/server/functions/recall.ts similarity index 86% rename from x-pack/plugins/observability_ai_assistant/public/functions/recall.ts rename to x-pack/plugins/observability_ai_assistant/server/functions/recall.ts index af825a53b7cb47..bd71fa11f1dc27 100644 --- a/x-pack/plugins/observability_ai_assistant/public/functions/recall.ts +++ b/x-pack/plugins/observability_ai_assistant/server/functions/recall.ts @@ -7,15 +7,16 @@ import type { Serializable } from '@kbn/utility-types'; import { omit } from 'lodash'; -import { MessageRole, RegisterFunctionDefinition } from '../../common/types'; -import type { ObservabilityAIAssistantService } from '../types'; +import { MessageRole } from '../../common'; +import { ObservabilityAIAssistantClient } from '../service/client'; +import { RegisterFunction } from '../service/types'; export function registerRecallFunction({ - service, + client, registerFunction, }: { - service: ObservabilityAIAssistantService; - registerFunction: RegisterFunctionDefinition; + client: ObservabilityAIAssistantClient; + registerFunction: RegisterFunction; }) { registerFunction( { @@ -79,15 +80,11 @@ export function registerRecallFunction({ const queriesWithUserPrompt = userPrompt ? 
[userPrompt, ...queries] : queries; - return service - .callApi('POST /internal/observability_ai_assistant/functions/recall', { - params: { - body: { - queries: queriesWithUserPrompt, - contexts, - }, - }, - signal, + return client + .recall({ + queries: queriesWithUserPrompt, + contexts, + // signal, }) .then((response): { content: Serializable } => ({ content: response.entries.map((entry) => diff --git a/x-pack/plugins/observability_ai_assistant/public/functions/summarize.ts b/x-pack/plugins/observability_ai_assistant/server/functions/summarize.ts similarity index 82% rename from x-pack/plugins/observability_ai_assistant/public/functions/summarize.ts rename to x-pack/plugins/observability_ai_assistant/server/functions/summarize.ts index 14f637591e613f..857afc5f5980e1 100644 --- a/x-pack/plugins/observability_ai_assistant/public/functions/summarize.ts +++ b/x-pack/plugins/observability_ai_assistant/server/functions/summarize.ts @@ -5,16 +5,12 @@ * 2.0. */ -import type { RegisterFunctionDefinition } from '../../common/types'; -import type { ObservabilityAIAssistantService } from '../types'; +import type { FunctionRegistrationParameters } from '.'; export function registerSummarizationFunction({ - service, + client, registerFunction, -}: { - service: ObservabilityAIAssistantService; - registerFunction: RegisterFunctionDefinition; -}) { +}: FunctionRegistrationParameters) { registerFunction( { name: 'summarize', @@ -65,19 +61,17 @@ export function registerSummarizationFunction({ { arguments: { id, text, is_correction: isCorrection, confidence, public: isPublic } }, signal ) => { - return service - .callApi('POST /internal/observability_ai_assistant/functions/summarize', { - params: { - body: { - id, - text, - is_correction: isCorrection, - confidence, - public: isPublic, - labels: {}, - }, + return client + .summarize({ + entry: { + id, + text, + is_correction: isCorrection, + confidence, + public: isPublic, + labels: {}, }, - signal, + // signal, }) .then(() => ({ 
content: { diff --git a/x-pack/plugins/observability_ai_assistant/server/index.ts b/x-pack/plugins/observability_ai_assistant/server/index.ts index 8660446357e34b..3db2883b915211 100644 --- a/x-pack/plugins/observability_ai_assistant/server/index.ts +++ b/x-pack/plugins/observability_ai_assistant/server/index.ts @@ -8,6 +8,11 @@ import type { PluginConfigDescriptor, PluginInitializerContext } from '@kbn/core/server'; import type { ObservabilityAIAssistantConfig } from './config'; +export type { + ObservabilityAIAssistantPluginSetup, + ObservabilityAIAssistantPluginStart, +} from './types'; + export type { ObservabilityAIAssistantServerRouteRepository } from './routes/get_global_observability_ai_assistant_route_repository'; import { config as configSchema } from './config'; diff --git a/x-pack/plugins/observability_ai_assistant/server/plugin.ts b/x-pack/plugins/observability_ai_assistant/server/plugin.ts index d6a256ddf6022c..833898e84d638c 100644 --- a/x-pack/plugins/observability_ai_assistant/server/plugin.ts +++ b/x-pack/plugins/observability_ai_assistant/server/plugin.ts @@ -32,6 +32,7 @@ import { ObservabilityAIAssistantPluginStartDependencies, } from './types'; import { addLensDocsToKb } from './service/kb_service/kb_docs/lens'; +import { registerFunctions } from './functions'; export class ObservabilityAIAssistantPlugin implements @@ -109,6 +110,8 @@ export class ObservabilityAIAssistantPlugin taskManager: plugins.taskManager, }); + service.registration(registerFunctions); + addLensDocsToKb({ service, logger: this.logger.get('kb').get('lens') }); registerServerRoutes({ @@ -120,7 +123,9 @@ export class ObservabilityAIAssistantPlugin }, }); - return {}; + return { + service, + }; } public start( diff --git a/x-pack/plugins/observability_ai_assistant/server/routes/chat/route.ts b/x-pack/plugins/observability_ai_assistant/server/routes/chat/route.ts index 90620156acf370..163b41135b5180 100644 --- 
a/x-pack/plugins/observability_ai_assistant/server/routes/chat/route.ts +++ b/x-pack/plugins/observability_ai_assistant/server/routes/chat/route.ts @@ -54,10 +54,17 @@ const chatRoute = createObservabilityAIAssistantServerRoute({ const stream = query.stream; + const controller = new AbortController(); + + request.events.aborted$.subscribe(() => { + controller.abort(); + }); + return client.chat({ messages, connectorId, stream, + signal: controller.signal, ...(functions.length ? { functions, @@ -68,6 +75,62 @@ const chatRoute = createObservabilityAIAssistantServerRoute({ }, }); +const chatCompleteRoute = createObservabilityAIAssistantServerRoute({ + endpoint: 'POST /internal/observability_ai_assistant/chat/complete', + options: { + tags: ['access:ai_assistant'], + }, + params: t.type({ + body: t.intersection([ + t.type({ + messages: t.array(messageRt), + connectorId: t.string, + persist: toBooleanRt, + }), + t.partial({ + conversationId: t.string, + title: t.string, + }), + ]), + }), + handler: async (resources): Promise => { + const { request, params, service } = resources; + + const client = await service.getClient({ request }); + + if (!client) { + throw notImplemented(); + } + + const { + body: { messages, connectorId, conversationId, title, persist }, + } = params; + + const controller = new AbortController(); + + request.events.aborted$.subscribe(() => { + controller.abort(); + }); + + const functionClient = await service.getFunctionClient({ + signal: controller.signal, + resources, + client, + }); + + return client.complete({ + messages, + connectorId, + conversationId, + title, + persist, + signal: controller.signal, + functionClient, + }); + }, +}); + export const chatRoutes = { ...chatRoute, + ...chatCompleteRoute, }; diff --git a/x-pack/plugins/observability_ai_assistant/server/routes/conversations/route.ts b/x-pack/plugins/observability_ai_assistant/server/routes/conversations/route.ts index 87aba7dc6757f2..b39468c3e06c09 100644 --- 
a/x-pack/plugins/observability_ai_assistant/server/routes/conversations/route.ts +++ b/x-pack/plugins/observability_ai_assistant/server/routes/conversations/route.ts @@ -108,37 +108,6 @@ const updateConversationRoute = createObservabilityAIAssistantServerRoute({ }, }); -const updateConversationTitleBasedOnMessages = createObservabilityAIAssistantServerRoute({ - endpoint: 'PUT /internal/observability_ai_assistant/conversation/{conversationId}/auto_title', - params: t.type({ - path: t.type({ - conversationId: t.string, - }), - body: t.type({ - connectorId: t.string, - }), - }), - options: { - tags: ['access:ai_assistant'], - }, - handler: async (resources): Promise => { - const { service, request, params } = resources; - - const client = await service.getClient({ request }); - - if (!client) { - throw notImplemented(); - } - - const conversation = await client.autoTitle({ - conversationId: params.path.conversationId, - connectorId: params.body.connectorId, - }); - - return Promise.resolve(conversation); - }, -}); - const updateConversationTitle = createObservabilityAIAssistantServerRoute({ endpoint: 'PUT /internal/observability_ai_assistant/conversation/{conversationId}/title', params: t.type({ @@ -198,7 +167,6 @@ export const conversationRoutes = { ...findConversationsRoute, ...createConversationRoute, ...updateConversationRoute, - ...updateConversationTitleBasedOnMessages, ...updateConversationTitle, ...deleteConversationRoute, }; diff --git a/x-pack/plugins/observability_ai_assistant/server/routes/functions/route.ts b/x-pack/plugins/observability_ai_assistant/server/routes/functions/route.ts index 087e8c079b5ef3..5ac1db69e82347 100644 --- a/x-pack/plugins/observability_ai_assistant/server/routes/functions/route.ts +++ b/x-pack/plugins/observability_ai_assistant/server/routes/functions/route.ts @@ -4,19 +4,45 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. 
*/ -import datemath from '@elastic/datemath'; import { notImplemented } from '@hapi/boom'; -import { fromKueryExpression, toElasticsearchQuery } from '@kbn/es-query'; import { nonEmptyStringRt, toBooleanRt } from '@kbn/io-ts-utils'; import * as t from 'io-ts'; -import { omit } from 'lodash'; -import type { ParsedTechnicalFields } from '@kbn/rule-registry-plugin/common'; -import { - ALERT_STATUS, - ALERT_STATUS_ACTIVE, -} from '@kbn/rule-registry-plugin/common/technical_rule_data_field_names'; -import { createObservabilityAIAssistantServerRoute } from '../create_observability_ai_assistant_server_route'; +import { ContextDefinition, FunctionDefinition } from '../../../common/types'; import type { RecalledEntry } from '../../service/kb_service'; +import { createObservabilityAIAssistantServerRoute } from '../create_observability_ai_assistant_server_route'; + +const getFunctionsRoute = createObservabilityAIAssistantServerRoute({ + endpoint: 'GET /internal/observability_ai_assistant/functions', + options: { + tags: ['access:ai_assistant'], + }, + handler: async ( + resources + ): Promise<{ + functionDefinitions: FunctionDefinition[]; + contextDefinitions: ContextDefinition[]; + }> => { + const { service, request } = resources; + + const controller = new AbortController(); + request.events.aborted$.subscribe(() => { + controller.abort(); + }); + + const client = await service.getClient({ request }); + + const functionClient = await service.getFunctionClient({ + signal: controller.signal, + resources, + client, + }); + + return { + functionDefinitions: functionClient.getFunctions().map((fn) => fn.definition), + contextDefinitions: functionClient.getContexts(), + }; + }, +}); const functionElasticsearchRoute = createObservabilityAIAssistantServerRoute({ endpoint: 'POST /internal/observability_ai_assistant/functions/elasticsearch', @@ -55,105 +81,6 @@ const functionElasticsearchRoute = createObservabilityAIAssistantServerRoute({ }, }); -const OMITTED_ALERT_FIELDS = [ - 
'tags', - 'event.action', - 'event.kind', - 'kibana.alert.rule.execution.uuid', - 'kibana.alert.rule.revision', - 'kibana.alert.rule.tags', - 'kibana.alert.rule.uuid', - 'kibana.alert.workflow_status', - 'kibana.space_ids', - 'kibana.alert.time_range', - 'kibana.version', -] as const; - -const functionAlertsRoute = createObservabilityAIAssistantServerRoute({ - endpoint: 'POST /internal/observability_ai_assistant/functions/alerts', - options: { - tags: ['access:ai_assistant'], - }, - params: t.type({ - body: t.intersection([ - t.type({ - featureIds: t.array(t.string), - start: t.string, - end: t.string, - }), - t.partial({ - filter: t.string, - includeRecovered: toBooleanRt, - }), - ]), - }), - handler: async ( - resources - ): Promise<{ - content: { - total: number; - alerts: ParsedTechnicalFields[]; - }; - }> => { - const { - featureIds, - start: startAsDatemath, - end: endAsDatemath, - filter, - includeRecovered, - } = resources.params.body; - - const racContext = await resources.context.rac; - const alertsClient = await racContext.getAlertsClient(); - - const start = datemath.parse(startAsDatemath)!.valueOf(); - const end = datemath.parse(endAsDatemath)!.valueOf(); - - const kqlQuery = !filter ? [] : [toElasticsearchQuery(fromKueryExpression(filter))]; - - const response = await alertsClient.find({ - featureIds, - - query: { - bool: { - filter: [ - { - range: { - '@timestamp': { - gte: start, - lte: end, - }, - }, - }, - ...kqlQuery, - ...(!includeRecovered - ? 
[ - { - term: { - [ALERT_STATUS]: ALERT_STATUS_ACTIVE, - }, - }, - ] - : []), - ], - }, - }, - }); - - // trim some fields - const alerts = response.hits.hits.map((hit) => - omit(hit._source, ...OMITTED_ALERT_FIELDS) - ) as unknown as ParsedTechnicalFields[]; - - return { - content: { - total: (response.hits as { total: { value: number } }).total.value, - alerts, - }, - }; - }, -}); - const functionRecallRoute = createObservabilityAIAssistantServerRoute({ endpoint: 'POST /internal/observability_ai_assistant/functions/recall', params: t.type({ @@ -349,11 +276,11 @@ const functionGetDatasetInfoRoute = createObservabilityAIAssistantServerRoute({ }); export const functionRoutes = { + ...getFunctionsRoute, ...functionElasticsearchRoute, ...functionRecallRoute, ...functionSummariseRoute, ...setupKnowledgeBaseRoute, ...getKnowledgeBaseStatus, - ...functionAlertsRoute, ...functionGetDatasetInfoRoute, }; diff --git a/x-pack/plugins/observability_ai_assistant/server/routes/types.ts b/x-pack/plugins/observability_ai_assistant/server/routes/types.ts index 1766f5c2d55428..26753874d32241 100644 --- a/x-pack/plugins/observability_ai_assistant/server/routes/types.ts +++ b/x-pack/plugins/observability_ai_assistant/server/routes/types.ts @@ -8,6 +8,8 @@ import type { CustomRequestHandlerContext, KibanaRequest } from '@kbn/core/server'; import type { Logger } from '@kbn/logging'; import type { RacApiRequestHandlerContext } from '@kbn/rule-registry-plugin/server'; +import type { LicensingApiRequestHandlerContext } from '@kbn/licensing-plugin/server/types'; +import type { AlertingApiRequestHandlerContext } from '@kbn/alerting-plugin/server/types'; import type { ObservabilityAIAssistantService } from '../service'; import type { ObservabilityAIAssistantPluginSetupDependencies, @@ -16,6 +18,8 @@ import type { export type ObservabilityAIAssistantRequestHandlerContext = CustomRequestHandlerContext<{ rac: RacApiRequestHandlerContext; + licensing: LicensingApiRequestHandlerContext; + 
alerting: AlertingApiRequestHandlerContext; }>; export interface ObservabilityAIAssistantRouteHandlerResources { diff --git a/x-pack/plugins/observability_ai_assistant/server/service/chat_function_client/index.ts b/x-pack/plugins/observability_ai_assistant/server/service/chat_function_client/index.ts new file mode 100644 index 00000000000000..e4800ca61ec849 --- /dev/null +++ b/x-pack/plugins/observability_ai_assistant/server/service/chat_function_client/index.ts @@ -0,0 +1,91 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ +/* eslint-disable max-classes-per-file*/ + +import type { Validator, OutputUnit } from '@cfworker/json-schema'; +import { keyBy } from 'lodash'; +import type { + ContextDefinition, + ContextRegistry, + FunctionResponse, + Message, +} from '../../../common/types'; +import { filterFunctionDefinitions } from '../../../common/utils/filter_function_definitions'; +import { FunctionHandler, FunctionHandlerRegistry } from '../types'; + +export class FunctionArgsValidationError extends Error { + constructor(public readonly errors: OutputUnit[]) { + super('Function arguments are invalid'); + } +} + +export class ChatFunctionClient { + constructor( + private readonly contextRegistry: ContextRegistry, + private readonly functionRegistry: FunctionHandlerRegistry, + private readonly validators: Map + ) {} + + private validate(name: string, parameters: unknown) { + const validator = this.validators.get(name)!; + const result = validator.validate(parameters); + if (!result.valid) { + throw new FunctionArgsValidationError(result.errors); + } + } + + getContexts(): ContextDefinition[] { + return Array.from(this.contextRegistry.values()); + } + + getFunctions({ + contexts, + filter, + }: { contexts?: string[]; filter?: string } = {}): FunctionHandler[] { 
+ const allFunctions = Array.from(this.functionRegistry.values()); + + const functionsByName = keyBy(allFunctions, (definition) => definition.definition.name); + + const matchingDefinitions = filterFunctionDefinitions({ + contexts, + filter, + definitions: allFunctions.map((fn) => fn.definition), + }); + + return matchingDefinitions.map((definition) => functionsByName[definition.name]); + } + + hasFunction(name: string): boolean { + return this.functionRegistry.has(name); + } + + async executeFunction({ + name, + args, + messages, + signal, + connectorId, + }: { + name: string; + args: string | undefined; + messages: Message[]; + signal: AbortSignal; + connectorId: string; + }): Promise { + const fn = this.functionRegistry.get(name); + + if (!fn) { + throw new Error(`Function ${name} not found`); + } + + const parsedArguments = args ? JSON.parse(args) : {}; + + this.validate(name, parsedArguments); + + return await fn.respond({ arguments: parsedArguments, messages, connectorId }, signal); + } +} diff --git a/x-pack/plugins/observability_ai_assistant/server/service/client/handle_llm_response.ts b/x-pack/plugins/observability_ai_assistant/server/service/client/handle_llm_response.ts new file mode 100644 index 00000000000000..39d9b1a87c3756 --- /dev/null +++ b/x-pack/plugins/observability_ai_assistant/server/service/client/handle_llm_response.ts @@ -0,0 +1,71 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +import { Observable } from 'rxjs'; +import { v4 } from 'uuid'; +import { Message, MessageRole } from '../../../common'; +import { + StreamingChatResponseEvent, + StreamingChatResponseEventType, +} from '../../../common/conversation_complete'; +import type { CreateChatCompletionResponseChunk } from '../../../common/types'; + +export function handleLlmResponse({ + signal, + write, + source$, +}: { + signal: AbortSignal; + write: (event: StreamingChatResponseEvent) => Promise; + source$: Observable; +}): Promise<{ id: string; message: Message['message'] }> { + return new Promise<{ message: Message['message']; id: string }>((resolve, reject) => { + const message = { + content: '', + role: MessageRole.Assistant, + function_call: { name: '', arguments: '', trigger: MessageRole.Assistant as const }, + }; + + const id = v4(); + const subscription = source$.subscribe({ + next: (chunk) => { + const delta = chunk.choices[0].delta; + + message.content += delta.content || ''; + message.function_call.name += delta.function_call?.name || ''; + message.function_call.arguments += delta.function_call?.arguments || ''; + + write({ + type: StreamingChatResponseEventType.ChatCompletionChunk, + message: delta, + id, + }); + }, + complete: () => { + resolve({ id, message }); + }, + error: (error) => { + reject(error); + }, + }); + + signal.addEventListener('abort', () => { + subscription.unsubscribe(); + reject(new Error('Request aborted')); + }); + }).then(async ({ id, message }) => { + await write({ + type: StreamingChatResponseEventType.MessageAdd, + message: { + '@timestamp': new Date().toISOString(), + message, + }, + id, + }); + return { id, message }; + }); +} diff --git a/x-pack/plugins/observability_ai_assistant/server/service/client/index.test.ts b/x-pack/plugins/observability_ai_assistant/server/service/client/index.test.ts new file mode 100644 index 00000000000000..bafc3e5cfc643d --- /dev/null +++ 
b/x-pack/plugins/observability_ai_assistant/server/service/client/index.test.ts @@ -0,0 +1,1127 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ +import type { ActionsClient } from '@kbn/actions-plugin/server/actions_client'; +import type { ElasticsearchClient, Logger } from '@kbn/core/server'; +import type { DeeplyMockedKeys } from '@kbn/utility-types-jest'; +import { merge } from 'lodash'; +import { Subject } from 'rxjs'; +import { PassThrough, type Readable } from 'stream'; +import { finished } from 'stream/promises'; +import { ObservabilityAIAssistantClient } from '.'; +import { createResourceNamesMap } from '..'; +import { MessageRole, type Message } from '../../../common'; +import { StreamingChatResponseEventType } from '../../../common/conversation_complete'; +import type { CreateChatCompletionResponseChunk } from '../../../public/types'; +import type { ChatFunctionClient } from '../chat_function_client'; +import type { KnowledgeBaseService } from '../kb_service'; + +type ChunkDelta = CreateChatCompletionResponseChunk['choices'][number]['delta']; + +type LlmSimulator = ReturnType; + +const nextTick = () => { + return new Promise(process.nextTick); +}; + +const waitForNextWrite = async (stream: Readable): Promise => { + // this will fire before the client's internal write() promise is + // resolved + await new Promise((resolve) => stream.once('data', resolve)); + // so we wait another tick to let the client move to the next step + await nextTick(); +}; + +function createLlmSimulator() { + const stream = new PassThrough(); + return { + stream, + next: async (msg: ChunkDelta) => { + const chunk: CreateChatCompletionResponseChunk = { + created: 0, + id: '', + model: 'gpt-4', + object: 'chat.completion.chunk', + choices: [ + { + delta: msg, + }, + ], + 
}; + await new Promise((resolve, reject) => { + stream.write(`data: ${JSON.stringify(chunk)}\n`, undefined, (err) => { + return err ? reject(err) : resolve(); + }); + }); + }, + complete: async () => { + if (stream.destroyed) { + throw new Error('Stream is already destroyed'); + } + await new Promise((resolve) => stream.write('data: [DONE]', () => stream.end(resolve))); + }, + error: (error: Error) => { + stream.destroy(error); + }, + }; +} + +describe('Observability AI Assistant service', () => { + let client: ObservabilityAIAssistantClient; + + const actionsClientMock: DeeplyMockedKeys = { + execute: jest.fn(), + } as any; + + const esClientMock: DeeplyMockedKeys = { + search: jest.fn(), + index: jest.fn(), + update: jest.fn(), + } as any; + + const knowledgeBaseServiceMock: DeeplyMockedKeys = { + recall: jest.fn(), + } as any; + + const loggerMock: DeeplyMockedKeys = { + log: jest.fn(), + error: jest.fn(), + } as any; + + const functionClientMock: DeeplyMockedKeys = { + executeFunction: jest.fn(), + getFunctions: jest.fn().mockReturnValue([]), + hasFunction: jest.fn().mockImplementation((name) => { + return name !== 'recall'; + }), + } as any; + + let llmSimulator: LlmSimulator; + + function createClient() { + jest.clearAllMocks(); + + return new ObservabilityAIAssistantClient({ + actionsClient: actionsClientMock, + esClient: esClientMock, + knowledgeBaseService: knowledgeBaseServiceMock, + logger: loggerMock, + namespace: 'default', + resources: createResourceNamesMap(), + user: { + name: 'johndoe', + }, + }); + } + + function system(content: string | Omit): Message { + return merge( + { + '@timestamp': new Date().toString(), + message: { + role: MessageRole.System, + }, + }, + typeof content === 'string' ? { message: { content } } : content + ); + } + + function user(content: string | Omit): Message { + return merge( + { + '@timestamp': new Date().toString(), + message: { + role: MessageRole.User, + }, + }, + typeof content === 'string' ? 
{ message: { content } } : content + ); + } + + describe('when completing a conversation without an initial conversation id', () => { + let stream: Readable; + + let titleLlmPromiseResolve: (title: string) => void; + let titleLlmPromiseReject: Function; + + beforeEach(async () => { + client = createClient(); + actionsClientMock.execute + .mockImplementationOnce(async () => { + llmSimulator = createLlmSimulator(); + return { + actionId: '', + status: 'ok', + data: llmSimulator.stream, + }; + }) + .mockImplementationOnce(() => { + return new Promise((resolve, reject) => { + titleLlmPromiseResolve = (title: string) => { + const response = { + object: 'chat.completion', + choices: [ + { + message: { + role: MessageRole.Assistant, + content: title, + }, + }, + ], + }; + resolve({ + actionId: '', + status: 'ok', + data: response, + }); + }; + titleLlmPromiseReject = reject; + }); + }); + + stream = await client.complete({ + connectorId: 'foo', + messages: [system('This is a system message'), user('How many alerts do I have?')], + functionClient: functionClientMock, + signal: new AbortController().signal, + persist: true, + }); + }); + + describe('when streaming the response from the LLM', () => { + let dataHandler: jest.Mock; + + beforeEach(async () => { + dataHandler = jest.fn(); + + stream.on('data', dataHandler); + + await llmSimulator.next({ content: 'Hello' }); + }); + + it('calls the actions client with the messages', () => { + expect(actionsClientMock.execute.mock.calls[0]).toEqual([ + { + actionId: 'foo', + params: { + subAction: 'stream', + subActionParams: { + body: expect.any(String), + stream: true, + }, + }, + }, + ]); + }); + + it('calls the llm again to generate a new title', () => { + expect(actionsClientMock.execute.mock.calls[1]).toEqual([ + { + actionId: 'foo', + params: { + subAction: 'run', + subActionParams: { + body: expect.any(String), + }, + }, + }, + ]); + }); + + it('incrementally streams the response to the client', () => { + 
expect(dataHandler).toHaveBeenCalledTimes(1); + + expect(JSON.parse(dataHandler.mock.calls[0])).toEqual({ + id: expect.any(String), + message: { + content: 'Hello', + }, + type: StreamingChatResponseEventType.ChatCompletionChunk, + }); + }); + + describe('after the LLM errors out', () => { + it('adds an error to the stream and closes it', () => {}); + }); + + describe('when generating a title fails', () => { + beforeEach(async () => { + titleLlmPromiseReject(new Error('Failed generating title')); + + await nextTick(); + + await llmSimulator.complete(); + + await finished(stream); + }); + + it('falls back to the default title', () => { + expect(JSON.parse(dataHandler.mock.calls[2])).toEqual({ + conversation: { + title: 'New conversation', + id: expect.any(String), + last_updated: expect.any(String), + }, + type: StreamingChatResponseEventType.ConversationCreate, + }); + + expect(loggerMock.error).toHaveBeenCalled(); + }); + }); + + describe('after completing the response from the LLM', () => { + beforeEach(async () => { + await llmSimulator.next({ content: ' again' }); + + titleLlmPromiseResolve('An auto-generated title'); + + await nextTick(); + + await llmSimulator.complete(); + + await finished(stream); + }); + it('adds the completed message to the stream', () => { + expect(JSON.parse(dataHandler.mock.calls[1])).toEqual({ + id: expect.any(String), + message: { + content: ' again', + }, + type: StreamingChatResponseEventType.ChatCompletionChunk, + }); + + expect(JSON.parse(dataHandler.mock.calls[2])).toEqual({ + id: expect.any(String), + message: { + '@timestamp': expect.any(String), + message: { + content: 'Hello again', + role: MessageRole.Assistant, + function_call: { + arguments: '', + name: '', + trigger: MessageRole.Assistant, + }, + }, + }, + type: StreamingChatResponseEventType.MessageAdd, + }); + }); + + it('creates a new conversation with the automatically generated title', () => { + expect(JSON.parse(dataHandler.mock.calls[3])).toEqual({ + conversation: 
{ + title: 'An auto-generated title', + id: expect.any(String), + last_updated: expect.any(String), + }, + type: StreamingChatResponseEventType.ConversationCreate, + }); + + expect(esClientMock.index).toHaveBeenCalledWith({ + index: '.kibana-observability-ai-assistant-conversations', + refresh: true, + document: { + '@timestamp': expect.any(String), + conversation: { + id: expect.any(String), + last_updated: expect.any(String), + title: 'An auto-generated title', + }, + labels: {}, + numeric_labels: {}, + public: false, + namespace: 'default', + user: { + name: 'johndoe', + }, + messages: [ + { + '@timestamp': expect.any(String), + message: { + content: 'This is a system message', + role: MessageRole.System, + }, + }, + { + '@timestamp': expect.any(String), + message: { + content: 'How many alerts do I have?', + role: MessageRole.User, + }, + }, + { + '@timestamp': expect.any(String), + message: { + content: 'Hello again', + role: MessageRole.Assistant, + function_call: { + name: '', + arguments: '', + trigger: MessageRole.Assistant, + }, + }, + }, + ], + }, + }); + }); + }); + }); + }); + + describe('when completing a conversation with an initial conversation id', () => { + let stream: Readable; + + let dataHandler: jest.Mock; + + beforeEach(async () => { + client = createClient(); + actionsClientMock.execute.mockImplementationOnce(async () => { + llmSimulator = createLlmSimulator(); + return { + actionId: '', + status: 'ok', + data: llmSimulator.stream, + }; + }); + + esClientMock.search.mockImplementation(async () => { + return { + hits: { + hits: [ + { + _id: 'my-es-document-id', + _index: '.kibana-observability-ai-assistant-conversations', + _source: { + '@timestamp': new Date().toISOString(), + conversation: { + id: 'my-conversation-id', + title: 'My stored conversation', + last_updated: new Date().toISOString(), + }, + labels: {}, + numeric_labels: {}, + public: false, + messages: [ + system('This is a system message'), + user('How many alerts do I have?'), 
+ ], + }, + }, + ], + }, + } as any; + }); + + esClientMock.update.mockImplementationOnce(async () => { + return {} as any; + }); + + stream = await client.complete({ + connectorId: 'foo', + messages: [system('This is a system message'), user('How many alerts do I have?')], + functionClient: functionClientMock, + signal: new AbortController().signal, + conversationId: 'my-conversation-id', + persist: true, + }); + + dataHandler = jest.fn(); + + stream.on('data', dataHandler); + + await llmSimulator.next({ content: 'Hello' }); + + await llmSimulator.complete(); + + await finished(stream); + }); + + it('updates the conversation', () => { + expect(JSON.parse(dataHandler.mock.calls[2])).toEqual({ + conversation: { + title: 'My stored conversation', + id: expect.any(String), + last_updated: expect.any(String), + }, + type: StreamingChatResponseEventType.ConversationUpdate, + }); + + expect(esClientMock.update).toHaveBeenCalledWith({ + refresh: true, + index: '.kibana-observability-ai-assistant-conversations', + id: 'my-es-document-id', + doc: { + '@timestamp': expect.any(String), + conversation: { + id: expect.any(String), + last_updated: expect.any(String), + title: 'My stored conversation', + }, + labels: {}, + numeric_labels: {}, + public: false, + namespace: 'default', + user: { + name: 'johndoe', + }, + messages: [ + { + '@timestamp': expect.any(String), + message: { + content: 'This is a system message', + role: MessageRole.System, + }, + }, + { + '@timestamp': expect.any(String), + message: { + content: 'How many alerts do I have?', + role: MessageRole.User, + }, + }, + { + '@timestamp': expect.any(String), + message: { + content: 'Hello', + role: MessageRole.Assistant, + function_call: { + name: '', + arguments: '', + trigger: MessageRole.Assistant, + }, + }, + }, + ], + }, + }); + }); + }); + + describe('when the LLM response fails', () => { + let stream: Readable; + + let dataHandler: jest.Mock; + + beforeEach(async () => { + client = createClient(); + 
actionsClientMock.execute.mockImplementationOnce(async () => { + llmSimulator = createLlmSimulator(); + return { + actionId: '', + status: 'ok', + data: llmSimulator.stream, + }; + }); + + stream = await client.complete({ + connectorId: 'foo', + messages: [system('This is a system message'), user('How many alerts do I have?')], + functionClient: functionClientMock, + signal: new AbortController().signal, + title: 'My predefined title', + persist: true, + }); + + dataHandler = jest.fn(); + + stream.on('data', dataHandler); + + await llmSimulator.next({ content: 'Hello' }); + + await new Promise((resolve) => + llmSimulator.stream.write( + `data: ${JSON.stringify({ + error: { + message: 'Connection unexpectedly closed', + }, + })}\n`, + resolve + ) + ); + + await finished(stream); + + await llmSimulator.complete(); + }); + + it('ends the stream and writes an error', async () => { + expect(JSON.parse(dataHandler.mock.calls[1])).toEqual({ + error: { + message: 'Connection unexpectedly closed', + stack: expect.any(String), + }, + type: StreamingChatResponseEventType.ConversationCompletionError, + }); + }); + + it('does not create or update the conversation', async () => { + expect(esClientMock.index).not.toHaveBeenCalled(); + expect(esClientMock.update).not.toHaveBeenCalled(); + }); + }); + + describe('when the assistant answers with a function request', () => { + let stream: Readable; + + let dataHandler: jest.Mock; + + let respondFn: jest.Mock; + + let fnResponseResolve: (data: unknown) => void; + + let fnResponseReject: (error: Error) => void; + + beforeEach(async () => { + client = createClient(); + actionsClientMock.execute.mockImplementationOnce(async () => { + llmSimulator = createLlmSimulator(); + return { + actionId: '', + status: 'ok', + data: llmSimulator.stream, + }; + }); + + respondFn = jest.fn(); + + functionClientMock.getFunctions.mockImplementation(() => [ + { + definition: { + name: 'myFunction', + contexts: ['core'], + description: 'my-description', + 
descriptionForUser: '', + parameters: { + type: 'object', + additionalProperties: false, + properties: { + foo: { + type: 'string', + enum: ['bar'], + }, + }, + required: ['foo'], + }, + }, + respond: respondFn, + }, + ]); + + functionClientMock.executeFunction.mockImplementationOnce(() => { + return new Promise((resolve, reject) => { + fnResponseResolve = resolve; + fnResponseReject = reject; + }); + }); + + stream = await client.complete({ + connectorId: 'foo', + messages: [system('This is a system message'), user('How many alerts do I have?')], + functionClient: functionClientMock, + signal: new AbortController().signal, + title: 'My predefined title', + persist: true, + }); + + dataHandler = jest.fn(); + + stream.on('data', dataHandler); + + await llmSimulator.next({ + content: 'Hello', + function_call: { name: 'my-function', arguments: JSON.stringify({ foo: 'bar' }) }, + }); + + const prevLlmSimulator = llmSimulator; + + actionsClientMock.execute.mockImplementationOnce(async () => { + llmSimulator = createLlmSimulator(); + return { + actionId: '', + status: 'ok', + data: llmSimulator.stream, + }; + }); + + await prevLlmSimulator.complete(); + + await waitForNextWrite(stream); + }); + + describe('while the function call is pending', () => { + it('appends the request message', async () => { + expect(JSON.parse(dataHandler.mock.lastCall!)).toEqual({ + type: StreamingChatResponseEventType.MessageAdd, + id: expect.any(String), + message: { + '@timestamp': expect.any(String), + message: { + content: 'Hello', + role: MessageRole.Assistant, + function_call: { + name: 'my-function', + arguments: JSON.stringify({ foo: 'bar' }), + trigger: MessageRole.Assistant, + }, + }, + }, + }); + }); + + it('executes the function', () => { + expect(functionClientMock.executeFunction).toHaveBeenCalledWith({ + connectorId: 'foo', + name: 'my-function', + args: JSON.stringify({ foo: 'bar' }), + signal: expect.any(AbortSignal), + messages: [ + { + '@timestamp': expect.any(String), + 
message: { + role: MessageRole.System, + content: 'This is a system message', + }, + }, + { + '@timestamp': expect.any(String), + message: { + role: MessageRole.User, + content: 'How many alerts do I have?', + }, + }, + { + '@timestamp': expect.any(String), + message: { + role: MessageRole.Assistant, + content: 'Hello', + function_call: { + name: 'my-function', + arguments: JSON.stringify({ foo: 'bar' }), + trigger: MessageRole.Assistant, + }, + }, + }, + ], + }); + }); + + afterEach(async () => { + fnResponseResolve({ content: { my: 'content' } }); + await waitForNextWrite(stream); + + await llmSimulator.complete(); + await finished(stream); + }); + }); + + describe('and the function succeeds', () => { + beforeEach(async () => { + fnResponseResolve({ content: { my: 'content' } }); + await waitForNextWrite(stream); + }); + + it('appends the function response', () => { + expect(JSON.parse(dataHandler.mock.lastCall!)).toEqual({ + type: StreamingChatResponseEventType.MessageAdd, + id: expect.any(String), + message: { + '@timestamp': expect.any(String), + message: { + role: MessageRole.User, + name: 'my-function', + content: JSON.stringify({ + my: 'content', + }), + }, + }, + }); + }); + + it('sends the function response back to the llm', () => { + expect(actionsClientMock.execute).toHaveBeenCalledTimes(2); + expect(actionsClientMock.execute.mock.lastCall!).toEqual([ + { + actionId: 'foo', + params: { + subAction: 'stream', + subActionParams: { + body: expect.any(String), + stream: true, + }, + }, + }, + ]); + }); + + describe('and the assistant replies without a function request', () => { + beforeEach(async () => { + await llmSimulator.next({ content: 'I am done here' }); + await llmSimulator.complete(); + await waitForNextWrite(stream); + + await finished(stream); + }); + + it('appends the assistant reply', () => { + expect(JSON.parse(dataHandler.mock.calls[3])).toEqual({ + type: StreamingChatResponseEventType.ChatCompletionChunk, + id: expect.any(String), + message: 
{ + content: 'I am done here', + }, + }); + expect(JSON.parse(dataHandler.mock.calls[4])).toEqual({ + type: StreamingChatResponseEventType.MessageAdd, + id: expect.any(String), + message: { + '@timestamp': expect.any(String), + message: { + role: MessageRole.Assistant, + content: 'I am done here', + function_call: { + name: '', + arguments: '', + trigger: MessageRole.Assistant, + }, + }, + }, + }); + }); + + it('stores the conversation', () => { + expect(JSON.parse(dataHandler.mock.lastCall!)).toEqual({ + type: StreamingChatResponseEventType.ConversationCreate, + conversation: { + id: expect.any(String), + last_updated: expect.any(String), + title: 'My predefined title', + }, + }); + + expect(esClientMock.index).toHaveBeenCalled(); + + expect((esClientMock.index.mock.lastCall![0] as any).document.messages).toEqual([ + { + '@timestamp': expect.any(String), + message: { + content: 'This is a system message', + role: MessageRole.System, + }, + }, + { + '@timestamp': expect.any(String), + message: { + content: 'How many alerts do I have?', + role: MessageRole.User, + }, + }, + { + '@timestamp': expect.any(String), + message: { + content: 'Hello', + role: MessageRole.Assistant, + function_call: { + name: 'my-function', + arguments: JSON.stringify({ foo: 'bar' }), + trigger: MessageRole.Assistant, + }, + }, + }, + { + '@timestamp': expect.any(String), + message: { + content: JSON.stringify({ + my: 'content', + }), + name: 'my-function', + role: MessageRole.User, + }, + }, + { + '@timestamp': expect.any(String), + message: { + content: 'I am done here', + role: MessageRole.Assistant, + function_call: { + name: '', + arguments: '', + trigger: MessageRole.Assistant, + }, + }, + }, + ]); + }); + }); + }); + + describe('and the function fails', () => { + beforeEach(async () => { + fnResponseReject(new Error('Function failed')); + await waitForNextWrite(stream); + }); + + it('appends the function response', () => { + expect(JSON.parse(dataHandler.mock.lastCall!)).toEqual({ + 
type: StreamingChatResponseEventType.MessageAdd, + id: expect.any(String), + message: { + '@timestamp': expect.any(String), + message: { + role: MessageRole.User, + name: 'my-function', + content: JSON.stringify({ + message: 'Error: Function failed', + error: {}, + }), + }, + }, + }); + }); + + it('sends the function response back to the llm', () => { + expect(actionsClientMock.execute).toHaveBeenCalledTimes(2); + expect(actionsClientMock.execute.mock.lastCall!).toEqual([ + { + actionId: 'foo', + params: { + subAction: 'stream', + subActionParams: { + body: expect.any(String), + stream: true, + }, + }, + }, + ]); + }); + }); + + describe('and the function responds with an observable', () => { + let response$: Subject; + beforeEach(async () => { + response$ = new Subject(); + fnResponseResolve(response$); + await waitForNextWrite(stream); + }); + + it('appends the function response', () => { + expect(JSON.parse(dataHandler.mock.lastCall!)).toEqual({ + type: StreamingChatResponseEventType.MessageAdd, + id: expect.any(String), + message: { + '@timestamp': expect.any(String), + message: { + role: MessageRole.User, + name: 'my-function', + content: '{}', + }, + }, + }); + }); + + describe('if the observable completes', () => { + beforeEach(async () => { + response$.next({ + created: 0, + id: '', + model: 'gpt-4', + object: 'chat.completion.chunk', + choices: [ + { + delta: { + content: 'Hello', + }, + }, + ], + }); + response$.complete(); + + await finished(stream); + }); + + it('emits a completion chunk', () => { + expect(JSON.parse(dataHandler.mock.calls[3])).toEqual({ + type: StreamingChatResponseEventType.ChatCompletionChunk, + id: expect.any(String), + message: { + content: 'Hello', + }, + }); + }); + + it('appends the observable response', () => { + expect(JSON.parse(dataHandler.mock.calls[4])).toEqual({ + type: StreamingChatResponseEventType.MessageAdd, + id: expect.any(String), + message: { + '@timestamp': expect.any(String), + message: { + role: 
MessageRole.Assistant, + content: 'Hello', + function_call: { + name: '', + arguments: '', + trigger: MessageRole.Assistant, + }, + }, + }, + }); + }); + }); + + describe('if the observable errors out', () => { + let endStreamPromise: Promise; + + beforeEach(async () => { + response$.next({ + created: 0, + id: '', + model: 'gpt-4', + object: 'chat.completion.chunk', + choices: [ + { + delta: { + content: 'Hello', + }, + }, + ], + }); + response$.error(new Error('Unexpected error')); + + endStreamPromise = finished(stream); + + await endStreamPromise.catch(() => {}); + }); + + it('appends an error and fails the stream', () => { + // expect(endStreamPromise).rejects.toBeDefined(); + + expect(JSON.parse(dataHandler.mock.lastCall!)).toEqual({ + type: StreamingChatResponseEventType.ConversationCompletionError, + error: { + message: 'Unexpected error', + stack: expect.any(String), + }, + }); + }); + }); + }); + }); + + describe('when recall is available', () => { + let stream: Readable; + + let dataHandler: jest.Mock; + beforeEach(async () => { + client = createClient(); + actionsClientMock.execute.mockImplementationOnce(async () => { + llmSimulator = createLlmSimulator(); + return { + actionId: '', + status: 'ok', + data: llmSimulator.stream, + }; + }); + + functionClientMock.hasFunction.mockReturnValue(true); + + functionClientMock.executeFunction.mockImplementationOnce(async () => { + return { + content: [ + { + id: 'my_document', + text: 'My document', + }, + ], + }; + }); + + stream = await client.complete({ + connectorId: 'foo', + messages: [system('This is a system message'), user('How many alerts do I have?')], + functionClient: functionClientMock, + signal: new AbortController().signal, + persist: false, + }); + + dataHandler = jest.fn(); + + stream.on('data', dataHandler); + + await waitForNextWrite(stream); + + await llmSimulator.next({ + content: 'Hello', + }); + + await llmSimulator.complete(); + + await finished(stream); + }); + + it('appends the recall 
request message', () => { + expect(JSON.parse(dataHandler.mock.calls[0]!)).toEqual({ + type: StreamingChatResponseEventType.MessageAdd, + id: expect.any(String), + message: { + '@timestamp': expect.any(String), + message: { + content: '', + role: MessageRole.Assistant, + function_call: { + name: 'recall', + arguments: JSON.stringify({ queries: [], contexts: [] }), + trigger: MessageRole.Assistant, + }, + }, + }, + }); + }); + + it('appends the recall response', () => { + expect(JSON.parse(dataHandler.mock.calls[1]!)).toEqual({ + type: StreamingChatResponseEventType.MessageAdd, + id: expect.any(String), + message: { + '@timestamp': expect.any(String), + message: { + content: JSON.stringify([{ id: 'my_document', text: 'My document' }]), + role: MessageRole.User, + name: 'recall', + }, + }, + }); + }); + + it('appends the response from the LLM', () => { + expect(JSON.parse(dataHandler.mock.calls[2]!)).toEqual({ + type: StreamingChatResponseEventType.ChatCompletionChunk, + id: expect.any(String), + message: { + content: 'Hello', + }, + }); + + expect(JSON.parse(dataHandler.mock.calls[3]!)).toEqual({ + type: StreamingChatResponseEventType.MessageAdd, + id: expect.any(String), + message: { + '@timestamp': expect.any(String), + message: { + content: 'Hello', + function_call: { + name: '', + arguments: '', + trigger: MessageRole.Assistant, + }, + role: MessageRole.Assistant, + }, + }, + }); + }); + }); +}); diff --git a/x-pack/plugins/observability_ai_assistant/server/service/client/index.ts b/x-pack/plugins/observability_ai_assistant/server/service/client/index.ts index 2332f63a54c78e..c64badcece9dc4 100644 --- a/x-pack/plugins/observability_ai_assistant/server/service/client/index.ts +++ b/x-pack/plugins/observability_ai_assistant/server/service/client/index.ts @@ -10,16 +10,24 @@ import type { ActionsClient } from '@kbn/actions-plugin/server'; import type { ElasticsearchClient } from '@kbn/core/server'; import type { Logger } from '@kbn/logging'; import type { 
PublicMethodsOf } from '@kbn/utility-types'; -import { compact, isEmpty, merge, omit } from 'lodash'; +import { compact, isEmpty, last, merge, omit, pick } from 'lodash'; import type { ChatCompletionFunctions, ChatCompletionRequestMessage, CreateChatCompletionRequest, CreateChatCompletionResponse, } from 'openai'; +import { isObservable } from 'rxjs'; import { PassThrough, Readable } from 'stream'; import { v4 } from 'uuid'; import { + ConversationNotFoundError, + isChatCompletionError, + StreamingChatResponseEventType, + type StreamingChatResponseEvent, +} from '../../../common/conversation_complete'; +import { + FunctionResponse, MessageRole, type CompatibleJSONSchema, type Conversation, @@ -28,9 +36,13 @@ import { type KnowledgeBaseEntry, type Message, } from '../../../common/types'; +import { processOpenAiStream } from '../../../common/utils/process_openai_stream'; +import type { ChatFunctionClient } from '../chat_function_client'; import type { KnowledgeBaseService, RecalledEntry } from '../kb_service'; import type { ObservabilityAIAssistantResourceNames } from '../types'; import { getAccessQuery } from '../util/get_access_query'; +import { streamIntoObservable } from '../util/stream_into_observable'; +import { handleLlmResponse } from './handle_llm_response'; export class ObservabilityAIAssistantClient { constructor( @@ -104,18 +116,263 @@ export class ObservabilityAIAssistantClient { }); }; + complete = async ( + params: { + messages: Message[]; + connectorId: string; + signal: AbortSignal; + functionClient: ChatFunctionClient; + persist: boolean; + } & ({ conversationId: string } | { title?: string }) + ) => { + const stream = new PassThrough(); + + const { messages, connectorId, signal, functionClient, persist } = params; + + let conversationId: string = ''; + let title: string = ''; + if ('conversationId' in params) { + conversationId = params.conversationId; + } + + if ('title' in params) { + title = params.title || ''; + } + + function write(event: 
StreamingChatResponseEvent) { + if (stream.destroyed) { + return Promise.resolve(); + } + + return new Promise((resolve, reject) => { + stream.write(`${JSON.stringify(event)}\n`, 'utf-8', (err) => { + if (err) { + reject(err); + return; + } + resolve(); + }); + }); + } + + function fail(error: Error) { + const code = isChatCompletionError(error) ? error.code : undefined; + write({ + type: StreamingChatResponseEventType.ConversationCompletionError, + error: { + message: error.message, + stack: error.stack, + code, + }, + }).finally(() => { + stream.end(); + }); + } + + const next = async (nextMessages: Message[]): Promise => { + const lastMessage = last(nextMessages); + + const isUserMessage = lastMessage?.message.role === MessageRole.User; + + const isUserMessageWithoutFunctionResponse = isUserMessage && !lastMessage?.message.name; + + const recallFirst = + isUserMessageWithoutFunctionResponse && functionClient.hasFunction('recall'); + + const isAssistantMessageWithFunctionRequest = + lastMessage?.message.role === MessageRole.Assistant && + !!lastMessage?.message.function_call?.name; + + if (recallFirst) { + const addedMessage = { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.Assistant, + content: '', + function_call: { + name: 'recall', + arguments: JSON.stringify({ + queries: [], + contexts: [], + }), + trigger: MessageRole.Assistant as const, + }, + }, + }; + await write({ + type: StreamingChatResponseEventType.MessageAdd, + id: v4(), + message: addedMessage, + }); + return await next(nextMessages.concat(addedMessage)); + } else if (isUserMessage) { + const { message } = await handleLlmResponse({ + signal, + write, + source$: streamIntoObservable( + await this.chat({ + messages: nextMessages, + connectorId, + stream: true, + signal, + functions: functionClient + .getFunctions() + .map((fn) => pick(fn.definition, 'name', 'description', 'parameters')), + }) + ).pipe(processOpenAiStream()), + }); + return await 
next(nextMessages.concat({ message, '@timestamp': new Date().toISOString() })); + } + + if (isAssistantMessageWithFunctionRequest) { + const functionResponse = await functionClient + .executeFunction({ + connectorId, + name: lastMessage.message.function_call!.name, + messages: nextMessages, + args: lastMessage.message.function_call!.arguments, + signal, + }) + .catch((error): FunctionResponse => { + return { + content: { + message: error.toString(), + error, + }, + }; + }); + + if (signal.aborted) { + return; + } + + const functionResponseIsObservable = isObservable(functionResponse); + + const functionResponseMessage = { + '@timestamp': new Date().toISOString(), + message: { + name: lastMessage.message.function_call!.name, + ...(functionResponseIsObservable + ? { content: '{}' } + : { + content: JSON.stringify(functionResponse.content || {}), + data: functionResponse.data ? JSON.stringify(functionResponse.data) : undefined, + }), + role: MessageRole.User, + }, + }; + + nextMessages = nextMessages.concat(functionResponseMessage); + await write({ + type: StreamingChatResponseEventType.MessageAdd, + message: functionResponseMessage, + id: v4(), + }); + + if (functionResponseIsObservable) { + const { message } = await handleLlmResponse({ + signal, + write, + source$: functionResponse, + }); + return await next( + nextMessages.concat({ '@timestamp': new Date().toISOString(), message }) + ); + } + return await next(nextMessages); + } + + if (!persist) { + stream.end(); + return; + } + + // store the updated conversation and close the stream + if (conversationId) { + const conversation = await this.getConversationWithMetaFields(conversationId); + if (!conversation) { + throw new ConversationNotFoundError(); + } + + if (signal.aborted) { + return; + } + + const updatedConversation = await this.update( + merge({}, conversation._source, { messages: nextMessages }) + ); + await write({ + type: StreamingChatResponseEventType.ConversationUpdate, + conversation: 
updatedConversation.conversation, + }); + } else { + const generatedTitle = await titlePromise; + if (signal.aborted) { + return; + } + + const conversation = await this.create({ + '@timestamp': new Date().toISOString(), + conversation: { + title: generatedTitle || title || 'New conversation', + }, + messages: nextMessages, + labels: {}, + numeric_labels: {}, + public: false, + }); + await write({ + type: StreamingChatResponseEventType.ConversationCreate, + conversation: conversation.conversation, + }); + } + + stream.end(); + }; + + next(messages).catch((error) => { + if (!signal.aborted) { + this.dependencies.logger.error(error); + } + fail(error); + }); + + const titlePromise = + !conversationId && !title && persist + ? this.getGeneratedTitle({ + messages, + connectorId, + signal, + }).catch((error) => { + this.dependencies.logger.error( + 'Could not generate title, falling back to default title' + ); + this.dependencies.logger.error(error); + return Promise.resolve(undefined); + }) + : Promise.resolve(undefined); + + signal.addEventListener('abort', () => { + stream.end(); + }); + + return stream; + }; + chat = async ({ messages, connectorId, functions, functionCall, stream = true, + signal, }: { messages: Message[]; connectorId: string; functions?: Array<{ name: string; description: string; parameters: CompatibleJSONSchema }>; functionCall?: string; stream?: TStream; + signal: AbortSignal; }): Promise => { const messagesForOpenAI: ChatCompletionRequestMessage[] = compact( messages @@ -174,8 +431,8 @@ export class ObservabilityAIAssistantClient { const request: Omit & { model?: string } = { messages: messagesForOpenAI, - stream: true, - functions: functionsForOpenAI, + ...(stream ? { stream: true } : {}), + ...(!!functions?.length ? { functions: functionsForOpenAI } : {}), temperature: 0, function_call: functionCall ? { name: functionCall } : undefined, }; @@ -196,9 +453,15 @@ export class ObservabilityAIAssistantClient { } const response = stream - ? 
((executeResult.data as Readable).pipe(new PassThrough()) as Readable) + ? (executeResult.data as Readable).pipe(new PassThrough()) : (executeResult.data as CreateChatCompletionResponse); + if (response instanceof PassThrough) { + signal.addEventListener('abort', () => { + response.end(); + }); + } + return response as any; }; @@ -250,43 +513,34 @@ export class ObservabilityAIAssistantClient { return updatedConversation; }; - autoTitle = async ({ - conversationId, + getGeneratedTitle = async ({ + messages, connectorId, + signal, }: { - conversationId: string; + messages: Message[]; connectorId: string; + signal: AbortSignal; }) => { - const document = await this.getConversationWithMetaFields(conversationId); - if (!document) { - throw notFound(); - } - - const conversation = await this.get(conversationId); - - if (!conversation) { - throw notFound(); - } - const response = await this.chat({ messages: [ { '@timestamp': new Date().toISOString(), message: { role: MessageRole.Assistant, - content: conversation.messages.slice(1).reduce((acc, curr) => { + content: messages.slice(1).reduce((acc, curr) => { return `${acc} ${curr.message.role}: ${curr.message.content}`; - }, 'You are a helpful assistant for Elastic Observability. Assume the following message is the start of a conversation between you and a user; give this conversation a title based on this content: '), + }, 'You are a helpful assistant for Elastic Observability. Assume the following message is the start of a conversation between you and a user; give this conversation a title based on the content below. DO NOT UNDER ANY CIRCUMSTANCES wrap this title in single or double quotes. This title is shown in a list of conversations to the user, so title it for the user, not for you. 
Here is the content:'), }, }, ], connectorId, stream: false, + signal, }); if ('object' in response && response.object === 'chat.completion') { - const input = - response.choices[0].message?.content || `Conversation on ${conversation['@timestamp']}`; + const input = response.choices[0].message?.content || ''; // This regular expression captures a string enclosed in single or double quotes. // It extracts the string content without the quotes. @@ -296,19 +550,9 @@ export class ObservabilityAIAssistantClient { // - JustTextWithoutQuotes => Captures: JustTextWithoutQuotes const match = input.match(/^["']?([^"']+)["']?$/); const title = match ? match[1] : input; - - const updatedConversation: Conversation = merge( - {}, - conversation, - { conversation: { title } }, - this.getConversationUpdateValues(new Date().toISOString()) - ); - - await this.setTitle({ conversationId, title }); - - return updatedConversation; + return title; } - return conversation; + return undefined; }; setTitle = async ({ conversationId, title }: { conversationId: string; title: string }) => { diff --git a/x-pack/plugins/observability_ai_assistant/server/service/index.ts b/x-pack/plugins/observability_ai_assistant/server/service/index.ts index ab37c59563713e..0b085f7b735a3d 100644 --- a/x-pack/plugins/observability_ai_assistant/server/service/index.ts +++ b/x-pack/plugins/observability_ai_assistant/server/service/index.ts @@ -5,6 +5,7 @@ * 2.0. 
*/ +import type { Validator as IValidator } from '@cfworker/json-schema'; import * as Boom from '@hapi/boom'; import type { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server/plugin'; import { createConcreteWriteIndex, getDataStreamAdapter } from '@kbn/alerting-plugin/server'; @@ -13,39 +14,28 @@ import type { SecurityPluginStart } from '@kbn/security-plugin/server'; import { getSpaceIdFromPath } from '@kbn/spaces-plugin/common'; import type { TaskManagerSetupContract } from '@kbn/task-manager-plugin/server'; import { once } from 'lodash'; +import { ContextRegistry, RegisterContextDefinition } from '../../common/types'; import type { ObservabilityAIAssistantPluginStartDependencies } from '../types'; +import { ChatFunctionClient } from './chat_function_client'; import { ObservabilityAIAssistantClient } from './client'; import { conversationComponentTemplate } from './conversation_component_template'; import { kbComponentTemplate } from './kb_component_template'; import { KnowledgeBaseEntryOperationType, KnowledgeBaseService } from './kb_service'; -import type { ObservabilityAIAssistantResourceNames } from './types'; +import type { + ChatRegistrationFunction, + FunctionHandlerRegistry, + ObservabilityAIAssistantResourceNames, + RegisterFunction, + RespondFunctionResources, +} from './types'; import { splitKbText } from './util/split_kb_text'; function getResourceName(resource: string) { return `.kibana-observability-ai-assistant-${resource}`; } -export const ELSER_MODEL_ID = '.elser_model_2'; - -export const INDEX_QUEUED_DOCUMENTS_TASK_ID = 'observabilityAIAssistant:indexQueuedDocumentsTask'; - -export const INDEX_QUEUED_DOCUMENTS_TASK_TYPE = INDEX_QUEUED_DOCUMENTS_TASK_ID + 'Type'; - -type KnowledgeBaseEntryRequest = { id: string; labels?: Record } & ( - | { - text: string; - } - | { - texts: string[]; - } -); - -export class ObservabilityAIAssistantService { - private readonly core: CoreSetup; - private readonly logger: Logger; - private 
kbService?: KnowledgeBaseService; - - private readonly resourceNames: ObservabilityAIAssistantResourceNames = { +export function createResourceNamesMap() { + return { componentTemplate: { conversations: getResourceName('component-template-conversations'), kb: getResourceName('component-template-kb'), @@ -66,6 +56,31 @@ export class ObservabilityAIAssistantService { kb: getResourceName('kb-ingest-pipeline'), }, }; +} + +export const ELSER_MODEL_ID = '.elser_model_2'; + +export const INDEX_QUEUED_DOCUMENTS_TASK_ID = 'observabilityAIAssistant:indexQueuedDocumentsTask'; + +export const INDEX_QUEUED_DOCUMENTS_TASK_TYPE = INDEX_QUEUED_DOCUMENTS_TASK_ID + 'Type'; + +type KnowledgeBaseEntryRequest = { id: string; labels?: Record } & ( + | { + text: string; + } + | { + texts: string[]; + } +); + +export class ObservabilityAIAssistantService { + private readonly core: CoreSetup; + private readonly logger: Logger; + private kbService?: KnowledgeBaseService; + + private readonly resourceNames: ObservabilityAIAssistantResourceNames = createResourceNamesMap(); + + private readonly registrations: ChatRegistrationFunction[] = []; constructor({ logger, @@ -99,6 +114,12 @@ export class ObservabilityAIAssistantService { }); } + getKnowledgeBaseStatus() { + return this.init().then(() => { + return this.kbService!.status(); + }); + } + init = once(async () => { try { const [coreStart, pluginsStart] = await this.core.getStartServices(); @@ -223,13 +244,18 @@ export class ObservabilityAIAssistantService { }: { request: KibanaRequest; }): Promise { + const controller = new AbortController(); + + request.events.aborted$.subscribe(() => { + controller.abort(); + }); + const [_, [coreStart, plugins]] = await Promise.all([ this.init(), this.core.getStartServices() as Promise< [CoreStart, { security: SecurityPluginStart; actions: ActionsPluginStart }, unknown] >, ]); - const user = plugins.security.authc.getCurrentUser(request); if (!user) { @@ -254,6 +280,47 @@ export class 
ObservabilityAIAssistantService { }); } + async getFunctionClient({ + signal, + resources, + client, + }: { + signal: AbortSignal; + resources: RespondFunctionResources; + client: ObservabilityAIAssistantClient; + }): Promise { + const contextRegistry: ContextRegistry = new Map(); + const functionHandlerRegistry: FunctionHandlerRegistry = new Map(); + + // const Validator = await import('@cfworker/json-schema').then((m) => m.Validator); + + const validators = new Map(); + + const registerContext: RegisterContextDefinition = (context) => { + contextRegistry.set(context.name, context); + }; + + const registerFunction: RegisterFunction = (definition, respond) => { + validators.set( + definition.name, + // new Validator(definition.parameters as Schema, '2020-12', true) + { + validate: () => { + return { valid: true, errors: [] }; + }, + } as unknown as IValidator + ); + functionHandlerRegistry.set(definition.name, { definition, respond }); + }; + await Promise.all( + this.registrations.map((fn) => + fn({ signal, registerContext, registerFunction, resources, client }) + ) + ); + + return new ChatFunctionClient(contextRegistry, functionHandlerRegistry, validators); + } + addToKnowledgeBase(entries: KnowledgeBaseEntryRequest[]): void { this.init() .then(() => { @@ -306,4 +373,8 @@ export class ObservabilityAIAssistantService { }) ); } + + registration(fn: ChatRegistrationFunction) { + this.registrations.push(fn); + } } diff --git a/x-pack/plugins/observability_ai_assistant/server/service/types.ts b/x-pack/plugins/observability_ai_assistant/server/service/types.ts index 39c2b29dbb0269..d89bfd546c7023 100644 --- a/x-pack/plugins/observability_ai_assistant/server/service/types.ts +++ b/x-pack/plugins/observability_ai_assistant/server/service/types.ts @@ -5,6 +5,54 @@ * 2.0. 
*/ +import type { FromSchema } from 'json-schema-to-ts'; +import type { + CompatibleJSONSchema, + FunctionDefinition, + FunctionResponse, + Message, + RegisterContextDefinition, +} from '../../common/types'; +import type { ObservabilityAIAssistantRouteHandlerResources } from '../routes/types'; +import type { ObservabilityAIAssistantClient } from './client'; + +export type RespondFunctionResources = Pick< + ObservabilityAIAssistantRouteHandlerResources, + 'context' | 'logger' | 'plugins' | 'request' +>; + +type RespondFunction = ( + options: { + arguments: TArguments; + messages: Message[]; + connectorId: string; + }, + signal: AbortSignal +) => Promise; + +export interface FunctionHandler { + definition: FunctionDefinition; + respond: RespondFunction; +} + +export type RegisterFunction = < + TParameters extends CompatibleJSONSchema = any, + TResponse extends FunctionResponse = any, + TArguments = FromSchema +>( + definition: FunctionDefinition, + respond: RespondFunction +) => void; +export type FunctionHandlerRegistry = Map; + +export type ChatRegistrationFunction = ({}: { + signal: AbortSignal; + resources: RespondFunctionResources; + client: ObservabilityAIAssistantClient; + registerFunction: RegisterFunction; + registerContext: RegisterContextDefinition; +}) => Promise; + export interface ObservabilityAIAssistantResourceNames { componentTemplate: { conversations: string; diff --git a/x-pack/plugins/observability_ai_assistant/server/service/util/stream_into_observable.ts b/x-pack/plugins/observability_ai_assistant/server/service/util/stream_into_observable.ts new file mode 100644 index 00000000000000..5ccca849e9b36f --- /dev/null +++ b/x-pack/plugins/observability_ai_assistant/server/service/util/stream_into_observable.ts @@ -0,0 +1,23 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { concatMap, filter, from, map, Observable } from 'rxjs'; +import type { Readable } from 'stream'; + +export function streamIntoObservable(readable: Readable): Observable { + let lineBuffer = ''; + return from(readable).pipe( + map((chunk: Buffer) => chunk.toString('utf-8')), + map((part) => { + const lines = (lineBuffer + part).split('\n'); + lineBuffer = lines.pop() || ''; // Keep the last incomplete line for the next chunk + return lines; + }), + concatMap((lines) => lines), + filter((line) => line.trim() !== '') + ); +} diff --git a/x-pack/plugins/observability_ai_assistant/server/types.ts b/x-pack/plugins/observability_ai_assistant/server/types.ts index bdb283b9a1df2b..fe20daf90170aa 100644 --- a/x-pack/plugins/observability_ai_assistant/server/types.ts +++ b/x-pack/plugins/observability_ai_assistant/server/types.ts @@ -18,10 +18,13 @@ import type { TaskManagerStartContract, } from '@kbn/task-manager-plugin/server'; import { DataViewsServerPluginStart } from '@kbn/data-views-plugin/server'; +import { ObservabilityAIAssistantService } from './service'; /* eslint-disable @typescript-eslint/no-empty-interface*/ export interface ObservabilityAIAssistantPluginStart {} -export interface ObservabilityAIAssistantPluginSetup {} +export interface ObservabilityAIAssistantPluginSetup { + service: ObservabilityAIAssistantService; +} export interface ObservabilityAIAssistantPluginSetupDependencies { actions: ActionsPluginSetup; security: SecurityPluginSetup; From c2b91a3d8781c8ebb82cd214393a92827d1cbd07 Mon Sep 17 00:00:00 2001 From: kibanamachine <42973632+kibanamachine@users.noreply.github.com> Date: Tue, 5 Dec 2023 15:51:04 +0000 Subject: [PATCH 09/15] [CI] Auto-commit changed files from 'node scripts/lint_ts_projects --fix' --- x-pack/plugins/apm/tsconfig.json | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git 
a/x-pack/plugins/apm/tsconfig.json b/x-pack/plugins/apm/tsconfig.json index db829dc3ed5f87..ba23f19f78c4a3 100644 --- a/x-pack/plugins/apm/tsconfig.json +++ b/x-pack/plugins/apm/tsconfig.json @@ -106,7 +106,8 @@ "@kbn/custom-icons", "@kbn/elastic-agent-utils", "@kbn/shared-ux-link-redirect-app", - "@kbn/observability-get-padded-alert-time-range-util" + "@kbn/observability-get-padded-alert-time-range-util", + "@kbn/core-lifecycle-server" ], "exclude": ["target/**/*"] } From c3efc37c2fae162a4923692e97ab11d57b637485 Mon Sep 17 00:00:00 2001 From: Dario Gieselaar Date: Tue, 5 Dec 2023 19:55:21 +0100 Subject: [PATCH 10/15] Reinstate Kibana function --- .../server/functions/index.ts | 2 + .../server/functions/kibana.ts | 84 +++++++++++++++++++ .../server/service/client/index.ts | 39 +-------- .../server/service/index.ts | 6 +- 4 files changed, 90 insertions(+), 41 deletions(-) create mode 100644 x-pack/plugins/observability_ai_assistant/server/functions/kibana.ts diff --git a/x-pack/plugins/observability_ai_assistant/server/functions/index.ts b/x-pack/plugins/observability_ai_assistant/server/functions/index.ts index db673a725f2b85..d4e5219840c149 100644 --- a/x-pack/plugins/observability_ai_assistant/server/functions/index.ts +++ b/x-pack/plugins/observability_ai_assistant/server/functions/index.ts @@ -14,6 +14,7 @@ import { registerElasticsearchFunction } from './elasticsearch'; import { registerEsqlFunction } from './esql'; import { registerGetDatasetInfoFunction } from './get_dataset_info'; import { registerLensFunction } from './lens'; +import { registerKibanaFunction } from './kibana'; export type FunctionRegistrationParameters = Omit< Parameters[0], @@ -74,6 +75,7 @@ export const registerFunctions: ChatRegistrationFunction = async ({ } registerElasticsearchFunction(registrationParameters); + registerKibanaFunction(registrationParameters); registerEsqlFunction(registrationParameters); registerAlertsFunction(registrationParameters); 
registerGetDatasetInfoFunction(registrationParameters); diff --git a/x-pack/plugins/observability_ai_assistant/server/functions/kibana.ts b/x-pack/plugins/observability_ai_assistant/server/functions/kibana.ts new file mode 100644 index 00000000000000..49516e28c38e81 --- /dev/null +++ b/x-pack/plugins/observability_ai_assistant/server/functions/kibana.ts @@ -0,0 +1,84 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import axios from 'axios'; +import { format } from 'url'; +import type { FunctionRegistrationParameters } from '.'; + +export function registerKibanaFunction({ + registerFunction, + resources, +}: FunctionRegistrationParameters) { + registerFunction( + { + name: 'kibana', + contexts: ['core'], + description: + 'Call Kibana APIs on behalf of the user. Only call this function when the user has explicitly requested it, and you know how to call it, for example by querying the knowledge base or having the user explain it to you. 
Assume that pathnames, bodies and query parameters may have changed since your knowledge cut off date.', + descriptionForUser: 'Call Kibana APIs on behalf of the user', + parameters: { + type: 'object', + additionalProperties: false, + properties: { + method: { + type: 'string', + description: 'The HTTP method of the Kibana endpoint', + enum: ['GET', 'PUT', 'POST', 'DELETE', 'PATCH'] as const, + }, + pathname: { + type: 'string', + description: 'The pathname of the Kibana endpoint, excluding query parameters', + }, + query: { + type: 'object', + description: 'The query parameters, as an object', + additionalProperties: { + type: 'string', + }, + }, + body: { + type: 'object', + description: 'The body of the request', + }, + }, + required: ['method', 'pathname'] as const, + }, + }, + ({ arguments: { method, pathname, body, query } }, signal) => { + const { request } = resources; + + const { + protocol, + host, + username, + password, + pathname: pathnameFromRequest, + } = request.rewrittenUrl!; + const nextUrl = { + host, + protocol, + username, + password, + pathname: pathnameFromRequest.replace( + '/internal/observability_ai_assistant/chat/complete', + pathname + ), + query, + }; + + return axios({ + method, + headers: request.headers, + url: format(nextUrl), + data: body ? 
JSON.stringify(body) : undefined, + signal, + }).then((response) => { + return { content: response.data }; + }); + } + ); +} diff --git a/x-pack/plugins/observability_ai_assistant/server/service/client/index.ts b/x-pack/plugins/observability_ai_assistant/server/service/client/index.ts index c64badcece9dc4..ad4787b610f197 100644 --- a/x-pack/plugins/observability_ai_assistant/server/service/client/index.ts +++ b/x-pack/plugins/observability_ai_assistant/server/service/client/index.ts @@ -299,7 +299,7 @@ export class ObservabilityAIAssistantClient { } const updatedConversation = await this.update( - merge({}, conversation._source, { messages: nextMessages }) + merge({}, omit(conversation._source, 'messages'), { messages: nextMessages }) ); await write({ type: StreamingChatResponseEventType.ConversationUpdate, @@ -392,42 +392,7 @@ export class ObservabilityAIAssistantClient { }) ); - // add recalled information to system message, so the LLM considers it more important - - const recallMessages = messagesForOpenAI.filter((message) => message.name === 'recall'); - - const recalledDocuments: Map = new Map(); - - recallMessages.forEach((message) => { - const entries = message.content - ? (JSON.parse(message.content) as Array<{ id: string; text: string }>) - : []; - - const ids: string[] = []; - - entries.forEach((entry) => { - const id = entry.id; - if (!recalledDocuments.has(id)) { - recalledDocuments.set(id, entry); - } - ids.push(id); - }); - - message.content = `The following documents, present in the system message, were recalled: ${ids.join( - ', ' - )}`; - }); - - const systemMessage = messagesForOpenAI.find((message) => message.role === MessageRole.System); - - if (systemMessage && recalledDocuments.size > 0) { - systemMessage.content += `The "recall" function is not available. Do not attempt to execute it. 
Recalled documents: ${JSON.stringify( - Array.from(recalledDocuments.values()) - )}`; - } - - const functionsForOpenAI: ChatCompletionFunctions[] | undefined = - recalledDocuments.size > 0 ? functions?.filter((fn) => fn.name !== 'recall') : functions; + const functionsForOpenAI: ChatCompletionFunctions[] | undefined = functions; const request: Omit & { model?: string } = { messages: messagesForOpenAI, diff --git a/x-pack/plugins/observability_ai_assistant/server/service/index.ts b/x-pack/plugins/observability_ai_assistant/server/service/index.ts index 0b085f7b735a3d..60b9e99cabb5a0 100644 --- a/x-pack/plugins/observability_ai_assistant/server/service/index.ts +++ b/x-pack/plugins/observability_ai_assistant/server/service/index.ts @@ -305,10 +305,8 @@ export class ObservabilityAIAssistantService { definition.name, // new Validator(definition.parameters as Schema, '2020-12', true) { - validate: () => { - return { valid: true, errors: [] }; - }, - } as unknown as IValidator + validate: () => ({ valid: true, errors: [] }), + } as any as IValidator ); functionHandlerRegistry.set(definition.name, { definition, respond }); }; From 4062dd59b250180b74bc90ec9a92b426ebf3a89f Mon Sep 17 00:00:00 2001 From: Dario Gieselaar Date: Tue, 5 Dec 2023 20:26:44 +0100 Subject: [PATCH 11/15] Remove unused translations --- x-pack/plugins/translations/translations/fr-FR.json | 1 - x-pack/plugins/translations/translations/ja-JP.json | 1 - x-pack/plugins/translations/translations/zh-CN.json | 1 - 3 files changed, 3 deletions(-) diff --git a/x-pack/plugins/translations/translations/fr-FR.json b/x-pack/plugins/translations/translations/fr-FR.json index ae552e749fd781..708b58ca433ce4 100644 --- a/x-pack/plugins/translations/translations/fr-FR.json +++ b/x-pack/plugins/translations/translations/fr-FR.json @@ -29686,7 +29686,6 @@ "xpack.observabilityAiAssistant.conversationStartTitle": "a démarré une conversation", "xpack.observabilityAiAssistant.couldNotFindConversationTitle": "Conversation 
introuvable", "xpack.observabilityAiAssistant.emptyConversationTitle": "Nouvelle conversation", - "xpack.observabilityAiAssistant.errorCreatingConversation": "Impossible de créer une conversation", "xpack.observabilityAiAssistant.errorSettingUpKnowledgeBase": "Impossible de configurer la base de connaissances", "xpack.observabilityAiAssistant.errorUpdatingConversation": "Impossible de mettre à jour la conversation", "xpack.observabilityAiAssistant.experimentalFunctionBanner.feedbackButton": "Donner un retour", diff --git a/x-pack/plugins/translations/translations/ja-JP.json b/x-pack/plugins/translations/translations/ja-JP.json index 120f39c1ba15c1..805688aec939b1 100644 --- a/x-pack/plugins/translations/translations/ja-JP.json +++ b/x-pack/plugins/translations/translations/ja-JP.json @@ -29686,7 +29686,6 @@ "xpack.observabilityAiAssistant.conversationStartTitle": "会話を開始しました", "xpack.observabilityAiAssistant.couldNotFindConversationTitle": "会話が見つかりません", "xpack.observabilityAiAssistant.emptyConversationTitle": "新しい会話", - "xpack.observabilityAiAssistant.errorCreatingConversation": "会話を作成できませんでした", "xpack.observabilityAiAssistant.errorSettingUpKnowledgeBase": "ナレッジベースをセットアップできませんでした", "xpack.observabilityAiAssistant.errorUpdatingConversation": "会話を更新できませんでした", "xpack.observabilityAiAssistant.experimentalFunctionBanner.feedbackButton": "フィードバックを作成する", diff --git a/x-pack/plugins/translations/translations/zh-CN.json b/x-pack/plugins/translations/translations/zh-CN.json index 99a78782a0c6c3..e1d21d4f0a7e0e 100644 --- a/x-pack/plugins/translations/translations/zh-CN.json +++ b/x-pack/plugins/translations/translations/zh-CN.json @@ -29683,7 +29683,6 @@ "xpack.observabilityAiAssistant.conversationStartTitle": "已开始对话", "xpack.observabilityAiAssistant.couldNotFindConversationTitle": "未找到对话", "xpack.observabilityAiAssistant.emptyConversationTitle": "新对话", - "xpack.observabilityAiAssistant.errorCreatingConversation": "无法创建对话", 
"xpack.observabilityAiAssistant.errorSettingUpKnowledgeBase": "无法设置知识库", "xpack.observabilityAiAssistant.errorUpdatingConversation": "无法更新对话", "xpack.observabilityAiAssistant.experimentalFunctionBanner.feedbackButton": "反馈", From 7a3b46151e8ea32be93063cc049c4d5073c136cc Mon Sep 17 00:00:00 2001 From: Dario Gieselaar Date: Wed, 6 Dec 2023 20:43:37 +0100 Subject: [PATCH 12/15] Review feedback and API tests --- .../get_apm_downstream_dependencies.ts | 2 +- x-pack/plugins/apm/server/plugin.ts | 2 +- .../common/conversation_complete.ts | 4 +- .../public/functions/lens.tsx | 2 +- .../public/service/create_chat_service.ts | 2 +- .../public/utils/builders.ts | 7 +- .../public/utils/storybook_decorator.tsx | 2 +- .../server/functions/get_dataset_info.ts | 2 +- .../server/functions/index.ts | 1 + .../server/plugin.ts | 2 +- .../server/service/client/index.test.ts | 19 +- .../server/service/client/index.ts | 41 +-- .../server/service/index.ts | 2 +- .../common/create_llm_proxy.ts | 155 ++++++++++++ .../common/create_openai_chunk.ts | 27 ++ .../tests/complete/complete.spec.ts | 234 ++++++++++++++++++ .../tests/functions/elasticsearch.spec.ts | 46 ---- 17 files changed, 460 insertions(+), 90 deletions(-) create mode 100644 x-pack/test/observability_ai_assistant_api_integration/common/create_llm_proxy.ts create mode 100644 x-pack/test/observability_ai_assistant_api_integration/common/create_openai_chunk.ts create mode 100644 x-pack/test/observability_ai_assistant_api_integration/tests/complete/complete.spec.ts delete mode 100644 x-pack/test/observability_ai_assistant_api_integration/tests/functions/elasticsearch.spec.ts diff --git a/x-pack/plugins/apm/server/assistant_functions/get_apm_downstream_dependencies.ts b/x-pack/plugins/apm/server/assistant_functions/get_apm_downstream_dependencies.ts index 45c1b876974aac..0440f684eeedd7 100644 --- a/x-pack/plugins/apm/server/assistant_functions/get_apm_downstream_dependencies.ts +++ 
b/x-pack/plugins/apm/server/assistant_functions/get_apm_downstream_dependencies.ts @@ -6,7 +6,7 @@ */ import { i18n } from '@kbn/i18n'; -import { FunctionRegistrationParameters } from '.'; +import type { FunctionRegistrationParameters } from '.'; import { getAssistantDownstreamDependencies } from '../routes/assistant_functions/get_apm_downstream_dependencies'; export function registerGetApmDownstreamDependenciesFunction({ diff --git a/x-pack/plugins/apm/server/plugin.ts b/x-pack/plugins/apm/server/plugin.ts index 9e1454a3e54943..220bdebbf32d70 100644 --- a/x-pack/plugins/apm/server/plugin.ts +++ b/x-pack/plugins/apm/server/plugin.ts @@ -233,7 +233,7 @@ export class APMPlugin this.logger?.error(e); }); - plugins.observabilityAIAssistant.service.registration( + plugins.observabilityAIAssistant.service.register( registerAssistantFunctions({ config: this.currentConfig!, coreSetup: core, diff --git a/x-pack/plugins/observability_ai_assistant/common/conversation_complete.ts b/x-pack/plugins/observability_ai_assistant/common/conversation_complete.ts index 4002eec9d10d78..f7e513efe3d017 100644 --- a/x-pack/plugins/observability_ai_assistant/common/conversation_complete.ts +++ b/x-pack/plugins/observability_ai_assistant/common/conversation_complete.ts @@ -59,12 +59,12 @@ export type ConversationUpdateEvent = StreamingChatResponseEventBase< } >; -type MessageAddEvent = StreamingChatResponseEventBase< +export type MessageAddEvent = StreamingChatResponseEventBase< StreamingChatResponseEventType.MessageAdd, { message: Message; id: string } >; -type ConversationCompletionErrorEvent = StreamingChatResponseEventBase< +export type ConversationCompletionErrorEvent = StreamingChatResponseEventBase< StreamingChatResponseEventType.ConversationCompletionError, { error: { message: string; stack?: string; code?: ChatCompletionErrorCode } } >; diff --git a/x-pack/plugins/observability_ai_assistant/public/functions/lens.tsx 
b/x-pack/plugins/observability_ai_assistant/public/functions/lens.tsx index 0d694a532a8056..22d4b91a5f9063 100644 --- a/x-pack/plugins/observability_ai_assistant/public/functions/lens.tsx +++ b/x-pack/plugins/observability_ai_assistant/public/functions/lens.tsx @@ -12,7 +12,7 @@ import type { LensEmbeddableInput, LensPublicStart } from '@kbn/lens-plugin/publ import React, { useState } from 'react'; import useAsync from 'react-use/lib/useAsync'; import { Assign } from 'utility-types'; -import { LensFunctionArguments } from '../../common/functions/lens'; +import type { LensFunctionArguments } from '../../common/functions/lens'; import type { ObservabilityAIAssistantPluginStartDependencies, ObservabilityAIAssistantService, diff --git a/x-pack/plugins/observability_ai_assistant/public/service/create_chat_service.ts b/x-pack/plugins/observability_ai_assistant/public/service/create_chat_service.ts index de2fe831992f4d..5de057934ad9ee 100644 --- a/x-pack/plugins/observability_ai_assistant/public/service/create_chat_service.ts +++ b/x-pack/plugins/observability_ai_assistant/public/service/create_chat_service.ts @@ -26,7 +26,7 @@ import { import { ChatCompletionErrorCode, ConversationCompletionError, - StreamingChatResponseEvent, + type StreamingChatResponseEvent, StreamingChatResponseEventType, } from '../../common/conversation_complete'; import { diff --git a/x-pack/plugins/observability_ai_assistant/public/utils/builders.ts b/x-pack/plugins/observability_ai_assistant/public/utils/builders.ts index f54995d3d17cd4..6ba40e8bc692dc 100644 --- a/x-pack/plugins/observability_ai_assistant/public/utils/builders.ts +++ b/x-pack/plugins/observability_ai_assistant/public/utils/builders.ts @@ -7,7 +7,12 @@ import { merge, uniqueId } from 'lodash'; import { DeepPartial } from 'utility-types'; -import { Conversation, FunctionDefinition, Message, MessageRole } from '../../common/types'; +import { + type Conversation, + type FunctionDefinition, + type Message, + MessageRole, +} from 
'../../common/types'; import { getAssistantSetupMessage } from '../service/get_assistant_setup_message'; type BuildMessageProps = DeepPartial & { diff --git a/x-pack/plugins/observability_ai_assistant/public/utils/storybook_decorator.tsx b/x-pack/plugins/observability_ai_assistant/public/utils/storybook_decorator.tsx index 88207439052ba0..04a30ce53059d3 100644 --- a/x-pack/plugins/observability_ai_assistant/public/utils/storybook_decorator.tsx +++ b/x-pack/plugins/observability_ai_assistant/public/utils/storybook_decorator.tsx @@ -10,7 +10,7 @@ import type { AuthenticatedUser } from '@kbn/security-plugin/common'; import type { SharePluginStart } from '@kbn/share-plugin/public'; import React, { ComponentType } from 'react'; import { Observable } from 'rxjs'; -import { StreamingChatResponseEvent } from '../../common/conversation_complete'; +import type { StreamingChatResponseEvent } from '../../common/conversation_complete'; import { ObservabilityAIAssistantAPIClient } from '../api'; import { ObservabilityAIAssistantChatServiceProvider } from '../context/observability_ai_assistant_chat_service_provider'; import { ObservabilityAIAssistantProvider } from '../context/observability_ai_assistant_provider'; diff --git a/x-pack/plugins/observability_ai_assistant/server/functions/get_dataset_info.ts b/x-pack/plugins/observability_ai_assistant/server/functions/get_dataset_info.ts index df9cd0cd231d62..bd48e1bda2f05f 100644 --- a/x-pack/plugins/observability_ai_assistant/server/functions/get_dataset_info.ts +++ b/x-pack/plugins/observability_ai_assistant/server/functions/get_dataset_info.ts @@ -169,7 +169,7 @@ export function registerGetDatasetInfoFunction({ fields: string[]; } ).fields - .filter((field) => fieldNames.includes(field)) + .filter((field) => fieldsInChunk.includes(field)) .map((field) => { const fieldDescriptors = groupedFields[field]; return `${field}:${fieldDescriptors diff --git a/x-pack/plugins/observability_ai_assistant/server/functions/index.ts 
b/x-pack/plugins/observability_ai_assistant/server/functions/index.ts index d4e5219840c149..6300852cbc0649 100644 --- a/x-pack/plugins/observability_ai_assistant/server/functions/index.ts +++ b/x-pack/plugins/observability_ai_assistant/server/functions/index.ts @@ -67,6 +67,7 @@ export const registerFunctions: ChatRegistrationFunction = async ({ description += `Here are principles you MUST adhere to, in order: - DO NOT make any assumptions about where and how users have stored their data. ALWAYS first call get_dataset_info function with empty string to get information about available indices. Once you know about available indices you MUST use this function again to get a list of available fields for specific index. If user provides an index name make sure its a valid index first before using it to retrieve the field list by calling this function with an empty string! `; + registerSummarizationFunction(registrationParameters); registerRecallFunction(registrationParameters); registerLensFunction(registrationParameters); diff --git a/x-pack/plugins/observability_ai_assistant/server/plugin.ts b/x-pack/plugins/observability_ai_assistant/server/plugin.ts index 2b77cb9c5a6eeb..577edfeb1da0c3 100644 --- a/x-pack/plugins/observability_ai_assistant/server/plugin.ts +++ b/x-pack/plugins/observability_ai_assistant/server/plugin.ts @@ -111,7 +111,7 @@ export class ObservabilityAIAssistantPlugin taskManager: plugins.taskManager, })); - service.registration(registerFunctions); + service.register(registerFunctions); addLensDocsToKb({ service, logger: this.logger.get('kb').get('lens') }); diff --git a/x-pack/plugins/observability_ai_assistant/server/service/client/index.test.ts b/x-pack/plugins/observability_ai_assistant/server/service/client/index.test.ts index 8769945922d516..a0eeaf50e5ebdf 100644 --- a/x-pack/plugins/observability_ai_assistant/server/service/client/index.test.ts +++ b/x-pack/plugins/observability_ai_assistant/server/service/client/index.test.ts @@ -161,21 
+161,13 @@ describe('Observability AI Assistant service', () => { .mockImplementationOnce(() => { return new Promise((resolve, reject) => { titleLlmPromiseResolve = (title: string) => { - const response = { - object: 'chat.completion', - choices: [ - { - message: { - role: MessageRole.Assistant, - content: title, - }, - }, - ], - }; + const titleLlmSimulator = createLlmSimulator(); + titleLlmSimulator.next({ content: title }); + titleLlmSimulator.complete(); resolve({ actionId: '', status: 'ok', - data: response, + data: titleLlmSimulator.stream, }); }; titleLlmPromiseReject = reject; @@ -222,9 +214,10 @@ describe('Observability AI Assistant service', () => { { actionId: 'foo', params: { - subAction: 'run', + subAction: 'stream', subActionParams: { body: expect.any(String), + stream: true, }, }, }, diff --git a/x-pack/plugins/observability_ai_assistant/server/service/client/index.ts b/x-pack/plugins/observability_ai_assistant/server/service/client/index.ts index 5bc76e3cae3a56..c12ed8d7f42480 100644 --- a/x-pack/plugins/observability_ai_assistant/server/service/client/index.ts +++ b/x-pack/plugins/observability_ai_assistant/server/service/client/index.ts @@ -12,12 +12,11 @@ import type { Logger } from '@kbn/logging'; import type { PublicMethodsOf } from '@kbn/utility-types'; import { compact, isEmpty, last, merge, omit, pick } from 'lodash'; import type { - ChatCompletionFunctions, ChatCompletionRequestMessage, CreateChatCompletionRequest, CreateChatCompletionResponse, } from 'openai'; -import { isObservable } from 'rxjs'; +import { isObservable, lastValueFrom } from 'rxjs'; import { PassThrough, Readable } from 'stream'; import { v4 } from 'uuid'; import { @@ -36,6 +35,7 @@ import { type KnowledgeBaseEntry, type Message, } from '../../../common/types'; +import { concatenateOpenAiChunks } from '../../../common/utils/concatenate_openai_chunks'; import { processOpenAiStream } from '../../../common/utils/process_openai_stream'; import type { ChatFunctionClient } from 
'../chat_function_client'; import { @@ -396,7 +396,7 @@ export class ObservabilityAIAssistantClient { }) ); - const functionsForOpenAI: ChatCompletionFunctions[] | undefined = functions; + const functionsForOpenAI = functions; const request: Omit & { model?: string } = { messages: messagesForOpenAI, @@ -491,12 +491,12 @@ export class ObservabilityAIAssistantClient { connectorId: string; signal: AbortSignal; }) => { - const response = await this.chat({ + const stream = await this.chat({ messages: [ { '@timestamp': new Date().toISOString(), message: { - role: MessageRole.Assistant, + role: MessageRole.User, content: messages.slice(1).reduce((acc, curr) => { return `${acc} ${curr.message.role}: ${curr.message.content}`; }, 'You are a helpful assistant for Elastic Observability. Assume the following message is the start of a conversation between you and a user; give this conversation a title based on the content below. DO NOT UNDER ANY CIRCUMSTANCES wrap this title in single or double quotes. This title is shown in a list of conversations to the user, so title it for the user, not for you. Here is the content:'), @@ -504,24 +504,25 @@ export class ObservabilityAIAssistantClient { }, ], connectorId, - stream: false, + stream: true, signal, }); - if ('object' in response && response.object === 'chat.completion') { - const input = response.choices[0].message?.content || ''; - - // This regular expression captures a string enclosed in single or double quotes. - // It extracts the string content without the quotes. - // Example matches: - // - "Hello, World!" => Captures: Hello, World! - // - 'Another Example' => Captures: Another Example - // - JustTextWithoutQuotes => Captures: JustTextWithoutQuotes - const match = input.match(/^["']?([^"']+)["']?$/); - const title = match ? 
match[1] : input; - return title; - } - return undefined; + const response = await lastValueFrom( + streamIntoObservable(stream).pipe(processOpenAiStream(), concatenateOpenAiChunks()) + ); + + const input = response.message?.content || ''; + + // This regular expression captures a string enclosed in single or double quotes. + // It extracts the string content without the quotes. + // Example matches: + // - "Hello, World!" => Captures: Hello, World! + // - 'Another Example' => Captures: Another Example + // - JustTextWithoutQuotes => Captures: JustTextWithoutQuotes + const match = input.match(/^["']?([^"']+)["']?$/); + const title = match ? match[1] : input; + return title; }; setTitle = async ({ conversationId, title }: { conversationId: string; title: string }) => { diff --git a/x-pack/plugins/observability_ai_assistant/server/service/index.ts b/x-pack/plugins/observability_ai_assistant/server/service/index.ts index b12c358f2361d0..e42a4739206449 100644 --- a/x-pack/plugins/observability_ai_assistant/server/service/index.ts +++ b/x-pack/plugins/observability_ai_assistant/server/service/index.ts @@ -377,7 +377,7 @@ export class ObservabilityAIAssistantService { ); } - registration(fn: ChatRegistrationFunction) { + register(fn: ChatRegistrationFunction) { this.registrations.push(fn); } } diff --git a/x-pack/test/observability_ai_assistant_api_integration/common/create_llm_proxy.ts b/x-pack/test/observability_ai_assistant_api_integration/common/create_llm_proxy.ts new file mode 100644 index 00000000000000..9bcdde1c75524f --- /dev/null +++ b/x-pack/test/observability_ai_assistant_api_integration/common/create_llm_proxy.ts @@ -0,0 +1,155 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +import getPort from 'get-port'; +import http, { type Server } from 'http'; +import { once, pull } from 'lodash'; +import { createOpenAiChunk } from './create_openai_chunk'; + +type RequestHandler = ( + request: http.IncomingMessage, + response: http.ServerResponse & { req: http.IncomingMessage } +) => void; + +type RequestFilterFunction = ({}: { + request: http.IncomingMessage; + data: string; +}) => Promise; + +export interface LlmResponseSimulator { + status: (code: number) => Promise; + next: ( + msg: + | string + | { + content?: string; + function_call?: { name: string; arguments: string }; + } + ) => Promise; + error: (error: Error) => Promise; + complete: () => Promise; + write: (chunk: string) => Promise; +} + +export class LlmProxy { + server: Server; + + requestHandlers: Array<{ + filter?: RequestFilterFunction; + handler: RequestHandler; + }> = []; + + constructor(private readonly port: number) { + this.server = http + .createServer(async (request, response) => { + const handlers = this.requestHandlers.concat(); + for (let i = 0; i < handlers.length; i++) { + const handler = handlers[i]; + let data: string = ''; + await new Promise((resolve, reject) => { + request.on('data', (chunk) => { + data += chunk.toString(); + }); + request.on('close', () => { + resolve(); + }); + }); + if (!handler.filter || (await handler.filter({ data, request }))) { + pull(this.requestHandlers, handler); + handler.handler(request, response); + return; + } + } + }) + .listen(port); + } + + getPort() { + return this.port; + } + + close() { + this.server.close(); + } + + respond( + cb: (simulator: LlmResponseSimulator) => Promise, + filter?: RequestFilterFunction + ): Promise { + return Promise.race([ + new Promise((outerPromiseResolve, outerPromiseReject) => { + const requestHandlerPromise = new Promise>((resolve) => { + this.requestHandlers.push({ + filter, + handler: (request, response) => { + resolve([request, response]); + }, + }); + }); + + function write(chunk: 
string) { + return withResponse( + (response) => new Promise((resolve) => response.write(chunk, () => resolve())) + ); + } + function end() { + return withResponse((response) => { + return new Promise((resolve) => response.end(resolve)); + }); + } + + function withResponse(responseCb: (response: Parameters[1]) => void) { + return requestHandlerPromise.then(([request, response]) => { + return responseCb(response); + }); + } + + cb({ + status: once((status: number) => { + return withResponse((response) => { + response.writeHead(status, { + 'Content-Type': 'text/event-stream', + 'Cache-Control': 'no-cache', + Connection: 'keep-alive', + }); + }); + }), + next: (msg) => { + const chunk = createOpenAiChunk(msg); + return write(`data: ${JSON.stringify(chunk)}\n`); + }, + write: (chunk: string) => { + return write(chunk); + }, + complete: async () => { + await write('data: [DONE]'); + await end(); + }, + error: async (error) => { + await write(`data: ${JSON.stringify({ error })}`); + await end(); + }, + }) + .then((result) => { + outerPromiseResolve(result); + }) + .catch((err) => { + outerPromiseReject(err); + }); + }), + new Promise((_, reject) => { + setTimeout(() => reject(new Error('Operation timed out')), 5000); + }), + ]); + } +} + +export async function createLlmProxy() { + const port = await getPort({ port: getPort.makeRange(9000, 9100) }); + + return new LlmProxy(port); +} diff --git a/x-pack/test/observability_ai_assistant_api_integration/common/create_openai_chunk.ts b/x-pack/test/observability_ai_assistant_api_integration/common/create_openai_chunk.ts new file mode 100644 index 00000000000000..7e39a7d73ce8b8 --- /dev/null +++ b/x-pack/test/observability_ai_assistant_api_integration/common/create_openai_chunk.ts @@ -0,0 +1,27 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import type { CreateChatCompletionResponseChunk } from '@kbn/observability-ai-assistant-plugin/common/types'; +import { v4 } from 'uuid'; + +export function createOpenAiChunk( + msg: string | { content?: string; function_call?: { name: string; arguments?: string } } +): CreateChatCompletionResponseChunk { + msg = typeof msg === 'string' ? { content: msg } : msg; + + return { + id: v4(), + object: 'chat.completion.chunk', + created: 0, + model: 'gpt-4', + choices: [ + { + delta: msg, + }, + ], + }; +} diff --git a/x-pack/test/observability_ai_assistant_api_integration/tests/complete/complete.spec.ts b/x-pack/test/observability_ai_assistant_api_integration/tests/complete/complete.spec.ts new file mode 100644 index 00000000000000..d15d6c84bc1f8c --- /dev/null +++ b/x-pack/test/observability_ai_assistant_api_integration/tests/complete/complete.spec.ts @@ -0,0 +1,234 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ +import { Response } from 'supertest'; +import { MessageRole, type Message } from '@kbn/observability-ai-assistant-plugin/common'; +import { omit } from 'lodash'; +import { PassThrough } from 'stream'; +import expect from '@kbn/expect'; +import { + StreamingChatResponseEvent, + StreamingChatResponseEventType, +} from '@kbn/observability-ai-assistant-plugin/common/conversation_complete'; +import { CreateChatCompletionRequest } from 'openai'; +import { createLlmProxy, LlmProxy } from '../../common/create_llm_proxy'; +import { createOpenAiChunk } from '../../common/create_openai_chunk'; +import { FtrProviderContext } from '../../common/ftr_provider_context'; + +export default function ApiTest({ getService }: FtrProviderContext) { + const supertest = getService('supertest'); + + const COMPLETE_API_URL = `/internal/observability_ai_assistant/chat/complete`; + + const messages: Message[] = [ + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.System, + content: 'You are a helpful assistant', + }, + }, + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.User, + content: 'Good morning!', + }, + }, + ]; + + describe('Complete', () => { + let proxy: LlmProxy; + let connectorId: string; + + before(async () => { + proxy = await createLlmProxy(); + + const response = await supertest + .post('/api/actions/connector') + .set('kbn-xsrf', 'foo') + .send({ + name: 'OpenAI', + connector_type_id: '.gen-ai', + config: { + apiProvider: 'OpenAI', + apiUrl: `http://localhost:${proxy.getPort()}`, + }, + secrets: { + apiKey: 'my-api-key', + }, + }) + .expect(200); + + connectorId = response.body.id; + }); + + after(async () => { + await supertest + .delete(`/api/actions/connector/${connectorId}`) + .set('kbn-xsrf', 'foo') + .expect(204); + + proxy.close(); + }); + + it.skip('returns a streaming response from the server', async () => { + await proxy.respond(async (simulator) => { + const receivedChunks: any[] = []; + + const 
passThrough = new PassThrough(); + + supertest + .post(COMPLETE_API_URL) + .set('kbn-xsrf', 'foo') + .send({ + messages, + connectorId, + persist: false, + }) + .pipe(passThrough); + + passThrough.on('data', (chunk) => { + receivedChunks.push(chunk.toString()); + }); + + await simulator.status(200); + const chunk = JSON.stringify(createOpenAiChunk('Hello')); + + await simulator.write(`data: ${chunk.substring(0, 10)}`); + await simulator.write(`${chunk.substring(10)}\n`); + await simulator.complete(); + + await new Promise((resolve) => passThrough.on('end', () => resolve())); + + const parsedChunks = receivedChunks + .join('') + .split('\n') + .map((line) => line.trim()) + .filter(Boolean) + .map((line) => JSON.parse(line) as StreamingChatResponseEvent); + + expect(parsedChunks.length).to.be(2); + expect(omit(parsedChunks[0], 'id')).to.eql({ + type: StreamingChatResponseEventType.ChatCompletionChunk, + message: { + content: 'Hello', + }, + }); + + expect(omit(parsedChunks[1], 'id', 'message.@timestamp')).to.eql({ + type: StreamingChatResponseEventType.MessageAdd, + message: { + message: { + content: 'Hello', + role: MessageRole.Assistant, + function_call: { + name: '', + arguments: '', + trigger: MessageRole.Assistant, + }, + }, + }, + }); + }); + }); + + it('creates a new conversation', async () => { + const lines = await proxy.respond( + async (titleSimulator) => { + return await proxy.respond( + async (chatSimulator) => { + const responsePromise = new Promise((resolve, reject) => { + supertest + .post(COMPLETE_API_URL) + .set('kbn-xsrf', 'foo') + .send({ + messages, + connectorId, + persist: true, + }) + .end((err, response) => { + if (err) { + return reject(err); + } + return resolve(response); + }); + }); + + await titleSimulator.status(200); + await titleSimulator.next('My generated title'); + await titleSimulator.complete(); + + await chatSimulator.status(200); + await chatSimulator.next('Hello'); + await chatSimulator.next(' again'); + await 
chatSimulator.complete(); + + const response = await responsePromise; + + return String(response.body) + .split('\n') + .map((line) => line.trim()) + .filter(Boolean) + .map((line) => JSON.parse(line) as StreamingChatResponseEvent); + }, + async ({ data }) => { + return new Promise((resolve) => { + const body: CreateChatCompletionRequest = JSON.parse(data); + resolve(body.messages.length !== 1); + }); + } + ); + }, + async ({ data }) => { + return new Promise((resolve) => { + const body: CreateChatCompletionRequest = JSON.parse(data); + resolve(body.messages.length === 1); + }); + } + ); + + expect(omit(lines[0], 'id')).to.eql({ + type: StreamingChatResponseEventType.ChatCompletionChunk, + message: { + content: 'Hello', + }, + }); + expect(omit(lines[1], 'id')).to.eql({ + type: StreamingChatResponseEventType.ChatCompletionChunk, + message: { + content: ' again', + }, + }); + expect(omit(lines[2], 'id', 'message.@timestamp')).to.eql({ + type: StreamingChatResponseEventType.MessageAdd, + message: { + message: { + content: 'Hello again', + function_call: { + arguments: '', + name: '', + trigger: MessageRole.Assistant, + }, + role: MessageRole.Assistant, + }, + }, + }); + expect(omit(lines[3], 'conversation.id', 'conversation.last_updated')).to.eql({ + type: StreamingChatResponseEventType.ConversationCreate, + conversation: { + title: 'My generated title', + }, + }); + }); + + // todo + it.skip('updates an existing conversation', async () => {}); + + // todo + it.skip('executes a function', async () => {}); + }); +} diff --git a/x-pack/test/observability_ai_assistant_api_integration/tests/functions/elasticsearch.spec.ts b/x-pack/test/observability_ai_assistant_api_integration/tests/functions/elasticsearch.spec.ts deleted file mode 100644 index 919085369cc704..00000000000000 --- a/x-pack/test/observability_ai_assistant_api_integration/tests/functions/elasticsearch.spec.ts +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright Elasticsearch B.V. 
and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -import expect from '@kbn/expect'; -import type { FtrProviderContext } from '../../common/ftr_provider_context'; - -export default function ApiTest({ getService }: FtrProviderContext) { - const observabilityAIAssistantAPIClient = getService('observabilityAIAssistantAPIClient'); - - describe('Functions: elasticsearch', () => { - it('executes a search request', async () => { - const response = await observabilityAIAssistantAPIClient - .readUser({ - endpoint: 'POST /internal/observability_ai_assistant/functions/elasticsearch', - params: { - body: { - method: 'GET', - path: '_all/_search', - body: { - query: { - bool: { - filter: [ - { - term: { - matches_no_docs: 'true', - }, - }, - ], - }, - }, - track_total_hits: false, - }, - }, - }, - }) - .expect(200); - - expect((response.body as any).hits.hits).to.eql([]); - expect((response.body as any).hits.total).to.eql(undefined); - }); - }); -} From 31f6dab2dfba104c87c73139d9065fdaa3d93a90 Mon Sep 17 00:00:00 2001 From: Dario Gieselaar Date: Thu, 7 Dec 2023 11:00:42 +0100 Subject: [PATCH 13/15] Fix stability of API tests --- .../common/create_llm_proxy.ts | 168 ++++++------ .../tests/complete/complete.spec.ts | 254 ++++++++++-------- 2 files changed, 218 insertions(+), 204 deletions(-) diff --git a/x-pack/test/observability_ai_assistant_api_integration/common/create_llm_proxy.ts b/x-pack/test/observability_ai_assistant_api_integration/common/create_llm_proxy.ts index 9bcdde1c75524f..51beb1b3671a5f 100644 --- a/x-pack/test/observability_ai_assistant_api_integration/common/create_llm_proxy.ts +++ b/x-pack/test/observability_ai_assistant_api_integration/common/create_llm_proxy.ts @@ -10,15 +10,15 @@ import http, { type Server } from 'http'; import { once, pull } from 'lodash'; import { createOpenAiChunk } 
from './create_openai_chunk'; -type RequestHandler = ( - request: http.IncomingMessage, - response: http.ServerResponse & { req: http.IncomingMessage } -) => void; +type Request = http.IncomingMessage; +type Response = http.ServerResponse & { req: http.IncomingMessage }; -type RequestFilterFunction = ({}: { - request: http.IncomingMessage; - data: string; -}) => Promise; +type RequestHandler = (request: Request, response: Response) => void; + +interface RequestInterceptor { + name: string; + when: (body: string) => boolean; +} export interface LlmResponseSimulator { status: (code: number) => Promise; @@ -38,32 +38,33 @@ export interface LlmResponseSimulator { export class LlmProxy { server: Server; - requestHandlers: Array<{ - filter?: RequestFilterFunction; - handler: RequestHandler; - }> = []; + interceptors: Array = []; constructor(private readonly port: number) { this.server = http .createServer(async (request, response) => { - const handlers = this.requestHandlers.concat(); - for (let i = 0; i < handlers.length; i++) { - const handler = handlers[i]; - let data: string = ''; - await new Promise((resolve, reject) => { - request.on('data', (chunk) => { - data += chunk.toString(); - }); - request.on('close', () => { - resolve(); - }); + const interceptors = this.interceptors.concat(); + + const body = await new Promise((resolve, reject) => { + let concatenated = ''; + request.on('data', (chunk) => { + concatenated += chunk.toString(); }); - if (!handler.filter || (await handler.filter({ data, request }))) { - pull(this.requestHandlers, handler); - handler.handler(request, response); + request.on('close', () => { + resolve(concatenated); + }); + }); + + while (interceptors.length) { + const interceptor = interceptors.shift()!; + if (interceptor.when(body)) { + pull(this.interceptors, interceptor); + interceptor.handle(request, response); return; } } + + throw new Error('No interceptors found to handle request'); }) .listen(port); } @@ -72,79 +73,70 @@ export class 
LlmProxy { return this.port; } + clear() { + this.interceptors.length = 0; + } + close() { this.server.close(); } - respond( - cb: (simulator: LlmResponseSimulator) => Promise, - filter?: RequestFilterFunction - ): Promise { - return Promise.race([ - new Promise((outerPromiseResolve, outerPromiseReject) => { - const requestHandlerPromise = new Promise>((resolve) => { - this.requestHandlers.push({ - filter, - handler: (request, response) => { - resolve([request, response]); - }, - }); - }); + intercept( + name: string, + when: RequestInterceptor['when'] + ): { + waitForIntercept: () => Promise; + } { + const waitForInterceptPromise = Promise.race([ + new Promise((outerResolve, outerReject) => { + this.interceptors.push({ + name, + when, + handle: (request, response) => { + function write(chunk: string) { + return new Promise((resolve) => response.write(chunk, () => resolve())); + } + function end() { + return new Promise((resolve) => response.end(resolve)); + } - function write(chunk: string) { - return withResponse( - (response) => new Promise((resolve) => response.write(chunk, () => resolve())) - ); - } - function end() { - return withResponse((response) => { - return new Promise((resolve) => response.end(resolve)); - }); - } - - function withResponse(responseCb: (response: Parameters[1]) => void) { - return requestHandlerPromise.then(([request, response]) => { - return responseCb(response); - }); - } + const simulator: LlmResponseSimulator = { + status: once(async (status: number) => { + response.writeHead(status, { + 'Content-Type': 'text/event-stream', + 'Cache-Control': 'no-cache', + Connection: 'keep-alive', + }); + }), + next: (msg) => { + const chunk = createOpenAiChunk(msg); + return write(`data: ${JSON.stringify(chunk)}\n`); + }, + write: (chunk: string) => { + return write(chunk); + }, + complete: async () => { + await write('data: [DONE]'); + await end(); + }, + error: async (error) => { + await write(`data: ${JSON.stringify({ error })}`); + await 
end(); + }, + }; - cb({ - status: once((status: number) => { - return withResponse((response) => { - response.writeHead(status, { - 'Content-Type': 'text/event-stream', - 'Cache-Control': 'no-cache', - Connection: 'keep-alive', - }); - }); - }), - next: (msg) => { - const chunk = createOpenAiChunk(msg); - return write(`data: ${JSON.stringify(chunk)}\n`); - }, - write: (chunk: string) => { - return write(chunk); + outerResolve(simulator); }, - complete: async () => { - await write('data: [DONE]'); - await end(); - }, - error: async (error) => { - await write(`data: ${JSON.stringify({ error })}`); - await end(); - }, - }) - .then((result) => { - outerPromiseResolve(result); - }) - .catch((err) => { - outerPromiseReject(err); - }); + }); }), - new Promise((_, reject) => { + new Promise((_, reject) => { setTimeout(() => reject(new Error('Operation timed out')), 5000); }), ]); + + return { + waitForIntercept: () => waitForInterceptPromise, + }; } } diff --git a/x-pack/test/observability_ai_assistant_api_integration/tests/complete/complete.spec.ts b/x-pack/test/observability_ai_assistant_api_integration/tests/complete/complete.spec.ts index d15d6c84bc1f8c..6601dbc213ab47 100644 --- a/x-pack/test/observability_ai_assistant_api_integration/tests/complete/complete.spec.ts +++ b/x-pack/test/observability_ai_assistant_api_integration/tests/complete/complete.spec.ts @@ -10,6 +10,7 @@ import { omit } from 'lodash'; import { PassThrough } from 'stream'; import expect from '@kbn/expect'; import { + ConversationCreateEvent, StreamingChatResponseEvent, StreamingChatResponseEventType, } from '@kbn/observability-ai-assistant-plugin/common/conversation_complete'; @@ -20,6 +21,7 @@ import { FtrProviderContext } from '../../common/ftr_provider_context'; export default function ApiTest({ getService }: FtrProviderContext) { const supertest = getService('supertest'); + const observabilityAIAssistantAPIClient = getService('observabilityAIAssistantAPIClient'); const COMPLETE_API_URL = 
`/internal/observability_ai_assistant/chat/complete`; @@ -75,153 +77,173 @@ export default function ApiTest({ getService }: FtrProviderContext) { proxy.close(); }); - it.skip('returns a streaming response from the server', async () => { - await proxy.respond(async (simulator) => { - const receivedChunks: any[] = []; + it('returns a streaming response from the server', async () => { + const interceptor = proxy.intercept('conversation', () => true); - const passThrough = new PassThrough(); + const receivedChunks: any[] = []; - supertest - .post(COMPLETE_API_URL) - .set('kbn-xsrf', 'foo') - .send({ - messages, - connectorId, - persist: false, - }) - .pipe(passThrough); + const passThrough = new PassThrough(); + + supertest + .post(COMPLETE_API_URL) + .set('kbn-xsrf', 'foo') + .send({ + messages, + connectorId, + persist: false, + }) + .pipe(passThrough); + + passThrough.on('data', (chunk) => { + receivedChunks.push(chunk.toString()); + }); + + const simulator = await interceptor.waitForIntercept(); + + await simulator.status(200); + const chunk = JSON.stringify(createOpenAiChunk('Hello')); + + await simulator.write(`data: ${chunk.substring(0, 10)}`); + await simulator.write(`${chunk.substring(10)}\n`); + await simulator.complete(); + + await new Promise((resolve) => passThrough.on('end', () => resolve())); + + const parsedChunks = receivedChunks + .join('') + .split('\n') + .map((line) => line.trim()) + .filter(Boolean) + .map((line) => JSON.parse(line) as StreamingChatResponseEvent); + + expect(parsedChunks.length).to.be(2); + expect(omit(parsedChunks[0], 'id')).to.eql({ + type: StreamingChatResponseEventType.ChatCompletionChunk, + message: { + content: 'Hello', + }, + }); + + expect(omit(parsedChunks[1], 'id', 'message.@timestamp')).to.eql({ + type: StreamingChatResponseEventType.MessageAdd, + message: { + message: { + content: 'Hello', + role: MessageRole.Assistant, + function_call: { + name: '', + arguments: '', + trigger: MessageRole.Assistant, + }, + }, + }, + 
}); + }); + + describe('when creating a new conversation', async () => { + let lines: StreamingChatResponseEvent[]; + before(async () => { + const titleInterceptor = proxy.intercept( + 'title', + (body) => (JSON.parse(body) as CreateChatCompletionRequest).messages.length === 1 + ); - passThrough.on('data', (chunk) => { - receivedChunks.push(chunk.toString()); + const conversationInterceptor = proxy.intercept( + 'conversation', + (body) => (JSON.parse(body) as CreateChatCompletionRequest).messages.length !== 1 + ); + + const responsePromise = new Promise((resolve, reject) => { + supertest + .post(COMPLETE_API_URL) + .set('kbn-xsrf', 'foo') + .send({ + messages, + connectorId, + persist: true, + }) + .end((err, response) => { + if (err) { + return reject(err); + } + return resolve(response); + }); }); - await simulator.status(200); - const chunk = JSON.stringify(createOpenAiChunk('Hello')); + const [conversationSimulator, titleSimulator] = await Promise.all([ + conversationInterceptor.waitForIntercept(), + titleInterceptor.waitForIntercept(), + ]); + + await titleSimulator.status(200); + await titleSimulator.next('My generated title'); + await titleSimulator.complete(); - await simulator.write(`data: ${chunk.substring(0, 10)}`); - await simulator.write(`${chunk.substring(10)}\n`); - await simulator.complete(); + await conversationSimulator.status(200); + await conversationSimulator.next('Hello'); + await conversationSimulator.next(' again'); + await conversationSimulator.complete(); - await new Promise((resolve) => passThrough.on('end', () => resolve())); + const response = await responsePromise; - const parsedChunks = receivedChunks - .join('') + lines = String(response.body) .split('\n') .map((line) => line.trim()) .filter(Boolean) .map((line) => JSON.parse(line) as StreamingChatResponseEvent); + }); - expect(parsedChunks.length).to.be(2); - expect(omit(parsedChunks[0], 'id')).to.eql({ + it('creates a new conversation', async () => { + expect(omit(lines[0], 
'id')).to.eql({ type: StreamingChatResponseEventType.ChatCompletionChunk, message: { content: 'Hello', }, }); - - expect(omit(parsedChunks[1], 'id', 'message.@timestamp')).to.eql({ + expect(omit(lines[1], 'id')).to.eql({ + type: StreamingChatResponseEventType.ChatCompletionChunk, + message: { + content: ' again', + }, + }); + expect(omit(lines[2], 'id', 'message.@timestamp')).to.eql({ type: StreamingChatResponseEventType.MessageAdd, message: { message: { - content: 'Hello', - role: MessageRole.Assistant, + content: 'Hello again', function_call: { - name: '', arguments: '', + name: '', trigger: MessageRole.Assistant, }, + role: MessageRole.Assistant, }, }, }); + expect(omit(lines[3], 'conversation.id', 'conversation.last_updated')).to.eql({ + type: StreamingChatResponseEventType.ConversationCreate, + conversation: { + title: 'My generated title', + }, + }); }); - }); - it('creates a new conversation', async () => { - const lines = await proxy.respond( - async (titleSimulator) => { - return await proxy.respond( - async (chatSimulator) => { - const responsePromise = new Promise((resolve, reject) => { - supertest - .post(COMPLETE_API_URL) - .set('kbn-xsrf', 'foo') - .send({ - messages, - connectorId, - persist: true, - }) - .end((err, response) => { - if (err) { - return reject(err); - } - return resolve(response); - }); - }); - - await titleSimulator.status(200); - await titleSimulator.next('My generated title'); - await titleSimulator.complete(); - - await chatSimulator.status(200); - await chatSimulator.next('Hello'); - await chatSimulator.next(' again'); - await chatSimulator.complete(); - - const response = await responsePromise; - - return String(response.body) - .split('\n') - .map((line) => line.trim()) - .filter(Boolean) - .map((line) => JSON.parse(line) as StreamingChatResponseEvent); - }, - async ({ data }) => { - return new Promise((resolve) => { - const body: CreateChatCompletionRequest = JSON.parse(data); - resolve(body.messages.length !== 1); - }); - } - 
); - }, - async ({ data }) => { - return new Promise((resolve) => { - const body: CreateChatCompletionRequest = JSON.parse(data); - resolve(body.messages.length === 1); - }); - } - ); - - expect(omit(lines[0], 'id')).to.eql({ - type: StreamingChatResponseEventType.ChatCompletionChunk, - message: { - content: 'Hello', - }, - }); - expect(omit(lines[1], 'id')).to.eql({ - type: StreamingChatResponseEventType.ChatCompletionChunk, - message: { - content: ' again', - }, - }); - expect(omit(lines[2], 'id', 'message.@timestamp')).to.eql({ - type: StreamingChatResponseEventType.MessageAdd, - message: { - message: { - content: 'Hello again', - function_call: { - arguments: '', - name: '', - trigger: MessageRole.Assistant, + after(async () => { + const createdConversationId = lines.filter( + (line): line is ConversationCreateEvent => + line.type === StreamingChatResponseEventType.ConversationCreate + )[0]?.conversation.id; + + await observabilityAIAssistantAPIClient + .writeUser({ + endpoint: 'DELETE /internal/observability_ai_assistant/conversation/{conversationId}', + params: { + path: { + conversationId: createdConversationId, + }, }, - role: MessageRole.Assistant, - }, - }, - }); - expect(omit(lines[3], 'conversation.id', 'conversation.last_updated')).to.eql({ - type: StreamingChatResponseEventType.ConversationCreate, - conversation: { - title: 'My generated title', - }, + }) + .expect(200); }); }); From 259a1d739a20364e1a665fb9072b961d00c84ca6 Mon Sep 17 00:00:00 2001 From: Dario Gieselaar Date: Thu, 7 Dec 2023 11:43:39 +0100 Subject: [PATCH 14/15] Replace @cfworker/json-schema with Ajv --- .../chat_function_client/index.test.ts | 75 +++++++++++++++++++ .../service/chat_function_client/index.ts | 12 +-- .../server/service/client/index.test.ts | 8 +- .../server/service/index.ts | 18 ++--- 4 files changed, 90 insertions(+), 23 deletions(-) create mode 100644 x-pack/plugins/observability_ai_assistant/server/service/chat_function_client/index.test.ts diff --git 
a/x-pack/plugins/observability_ai_assistant/server/service/chat_function_client/index.test.ts b/x-pack/plugins/observability_ai_assistant/server/service/chat_function_client/index.test.ts new file mode 100644 index 00000000000000..7d34404457d246 --- /dev/null +++ b/x-pack/plugins/observability_ai_assistant/server/service/chat_function_client/index.test.ts @@ -0,0 +1,75 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ +import Ajv, { type ValidateFunction } from 'ajv'; +import { ChatFunctionClient } from '.'; +import type { ContextRegistry } from '../../../common/types'; +import type { FunctionHandlerRegistry } from '../types'; + +describe('chatFunctionClient', () => { + describe('when executing a function with invalid arguments', () => { + let client: ChatFunctionClient; + + let respondFn: jest.Mock; + + beforeEach(() => { + const contextRegistry: ContextRegistry = new Map(); + contextRegistry.set('core', { + description: '', + name: 'core', + }); + + respondFn = jest.fn().mockImplementationOnce(async () => { + return {}; + }); + + const functionRegistry: FunctionHandlerRegistry = new Map(); + functionRegistry.set('myFunction', { + respond: respondFn, + definition: { + contexts: ['core'], + description: '', + name: 'myFunction', + parameters: { + properties: { + foo: { + type: 'string', + }, + }, + required: ['foo'], + }, + }, + }); + + const validators = new Map(); + + validators.set( + 'myFunction', + new Ajv({ strict: false }).compile( + functionRegistry.get('myFunction')!.definition.parameters + ) + ); + + client = new ChatFunctionClient(contextRegistry, functionRegistry, validators); + }); + + it('throws an error', async () => { + await expect(async () => { + await client.executeFunction({ + name: 'myFunction', + args: JSON.stringify({ + foo: 0, + }), 
+ messages: [], + signal: new AbortController().signal, + connectorId: '', + }); + }).rejects.toThrowError(`Function arguments are invalid`); + + expect(respondFn).not.toHaveBeenCalled(); + }); + }); +}); diff --git a/x-pack/plugins/observability_ai_assistant/server/service/chat_function_client/index.ts b/x-pack/plugins/observability_ai_assistant/server/service/chat_function_client/index.ts index e4800ca61ec849..e1d49696e51133 100644 --- a/x-pack/plugins/observability_ai_assistant/server/service/chat_function_client/index.ts +++ b/x-pack/plugins/observability_ai_assistant/server/service/chat_function_client/index.ts @@ -6,7 +6,7 @@ */ /* eslint-disable max-classes-per-file*/ -import type { Validator, OutputUnit } from '@cfworker/json-schema'; +import type { ErrorObject, ValidateFunction } from 'ajv'; import { keyBy } from 'lodash'; import type { ContextDefinition, @@ -18,7 +18,7 @@ import { filterFunctionDefinitions } from '../../../common/utils/filter_function import { FunctionHandler, FunctionHandlerRegistry } from '../types'; export class FunctionArgsValidationError extends Error { - constructor(public readonly errors: OutputUnit[]) { + constructor(public readonly errors: ErrorObject[]) { super('Function arguments are invalid'); } } @@ -27,14 +27,14 @@ export class ChatFunctionClient { constructor( private readonly contextRegistry: ContextRegistry, private readonly functionRegistry: FunctionHandlerRegistry, - private readonly validators: Map + private readonly validators: Map ) {} private validate(name: string, parameters: unknown) { const validator = this.validators.get(name)!; - const result = validator.validate(parameters); - if (!result.valid) { - throw new FunctionArgsValidationError(result.errors); + const result = validator(parameters); + if (!result) { + throw new FunctionArgsValidationError(validator.errors!); } } diff --git a/x-pack/plugins/observability_ai_assistant/server/service/client/index.test.ts 
b/x-pack/plugins/observability_ai_assistant/server/service/client/index.test.ts index a0eeaf50e5ebdf..102b623acfd6d1 100644 --- a/x-pack/plugins/observability_ai_assistant/server/service/client/index.test.ts +++ b/x-pack/plugins/observability_ai_assistant/server/service/client/index.test.ts @@ -985,14 +985,10 @@ describe('Observability AI Assistant service', () => { }); response$.error(new Error('Unexpected error')); - endStreamPromise = finished(stream); - - await endStreamPromise.catch(() => {}); + await finished(stream); }); - it('appends an error and fails the stream', () => { - // expect(endStreamPromise).rejects.toBeDefined(); - + it('appends an error', () => { expect(JSON.parse(dataHandler.mock.lastCall!)).toEqual({ type: StreamingChatResponseEventType.ConversationCompletionError, error: { diff --git a/x-pack/plugins/observability_ai_assistant/server/service/index.ts b/x-pack/plugins/observability_ai_assistant/server/service/index.ts index e42a4739206449..595217bc69e5c7 100644 --- a/x-pack/plugins/observability_ai_assistant/server/service/index.ts +++ b/x-pack/plugins/observability_ai_assistant/server/service/index.ts @@ -5,7 +5,6 @@ * 2.0. 
*/ -import type { Validator as IValidator } from '@cfworker/json-schema'; import * as Boom from '@hapi/boom'; import type { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server/plugin'; import { createConcreteWriteIndex, getDataStreamAdapter } from '@kbn/alerting-plugin/server'; @@ -13,6 +12,7 @@ import type { CoreSetup, CoreStart, KibanaRequest, Logger } from '@kbn/core/serv import type { SecurityPluginStart } from '@kbn/security-plugin/server'; import { getSpaceIdFromPath } from '@kbn/spaces-plugin/common'; import type { TaskManagerSetupContract } from '@kbn/task-manager-plugin/server'; +import Ajv, { type ValidateFunction } from 'ajv'; import { once } from 'lodash'; import { ContextRegistry, @@ -34,6 +34,10 @@ import type { } from './types'; import { splitKbText } from './util/split_kb_text'; +const ajv = new Ajv({ + strict: false, +}); + function getResourceName(resource: string) { return `.kibana-observability-ai-assistant-${resource}`; } @@ -296,22 +300,14 @@ export class ObservabilityAIAssistantService { const contextRegistry: ContextRegistry = new Map(); const functionHandlerRegistry: FunctionHandlerRegistry = new Map(); - // const Validator = await import('@cfworker/json-schema').then((m) => m.Validator); - - const validators = new Map(); + const validators = new Map(); const registerContext: RegisterContextDefinition = (context) => { contextRegistry.set(context.name, context); }; const registerFunction: RegisterFunction = (definition, respond) => { - validators.set( - definition.name, - // new Validator(definition.parameters as Schema, '2020-12', true) - { - validate: () => ({ valid: true, errors: [] }), - } as any as IValidator - ); + validators.set(definition.name, ajv.compile(definition.parameters)); functionHandlerRegistry.set(definition.name, { definition, respond }); }; await Promise.all( From 6a43d452b8376db8f0b366319ca23bbb3727e839 Mon Sep 17 00:00:00 2001 From: Dario Gieselaar Date: Thu, 7 Dec 2023 12:37:05 +0100 Subject: 
[PATCH 15/15] Tests for stream failure --- .../server/service/client/index.test.ts | 29 +++++++++++++++---- .../server/service/client/index.ts | 2 +- .../service/util/stream_into_observable.ts | 1 + 3 files changed, 25 insertions(+), 7 deletions(-) diff --git a/x-pack/plugins/observability_ai_assistant/server/service/client/index.test.ts b/x-pack/plugins/observability_ai_assistant/server/service/client/index.test.ts index 102b623acfd6d1..062377cda51127 100644 --- a/x-pack/plugins/observability_ai_assistant/server/service/client/index.test.ts +++ b/x-pack/plugins/observability_ai_assistant/server/service/client/index.test.ts @@ -37,6 +37,7 @@ const waitForNextWrite = async (stream: Readable): Promise => { function createLlmSimulator() { const stream = new PassThrough(); + return { stream, next: async (msg: ChunkDelta) => { @@ -237,7 +238,25 @@ describe('Observability AI Assistant service', () => { }); describe('after the LLM errors out', () => { - it('adds an error to the stream and closes it', () => {}); + beforeEach(async () => { + await llmSimulator.next({ content: ' again' }); + + llmSimulator.error(new Error('Unexpected error')); + + await finished(stream); + }); + + it('adds an error to the stream and closes it', () => { + expect(dataHandler).toHaveBeenCalledTimes(3); + + expect(JSON.parse(dataHandler.mock.lastCall!)).toEqual({ + error: { + message: 'Unexpected error', + stack: expect.any(String), + }, + type: StreamingChatResponseEventType.ConversationCompletionError, + }); + }); }); describe('when generating a title fails', () => { @@ -366,7 +385,7 @@ describe('Observability AI Assistant service', () => { }); }); - describe('when completing a conversation with an initial conversation id', () => { + describe('when completig a conversation with an initial conversation id', () => { let stream: Readable; let dataHandler: jest.Mock; @@ -537,9 +556,9 @@ describe('Observability AI Assistant service', () => { ) ); - await finished(stream); - await 
llmSimulator.complete(); + + await finished(stream); }); it('ends the stream and writes an error', async () => { @@ -967,8 +986,6 @@ describe('Observability AI Assistant service', () => { }); describe('if the observable errors out', () => { - let endStreamPromise: Promise; - beforeEach(async () => { response$.next({ created: 0, diff --git a/x-pack/plugins/observability_ai_assistant/server/service/client/index.ts b/x-pack/plugins/observability_ai_assistant/server/service/client/index.ts index c12ed8d7f42480..c111af3d92d48e 100644 --- a/x-pack/plugins/observability_ai_assistant/server/service/client/index.ts +++ b/x-pack/plugins/observability_ai_assistant/server/service/client/index.ts @@ -422,7 +422,7 @@ export class ObservabilityAIAssistantClient { } const response = stream - ? (executeResult.data as Readable).pipe(new PassThrough()) + ? (executeResult.data as Readable) : (executeResult.data as CreateChatCompletionResponse); if (response instanceof PassThrough) { diff --git a/x-pack/plugins/observability_ai_assistant/server/service/util/stream_into_observable.ts b/x-pack/plugins/observability_ai_assistant/server/service/util/stream_into_observable.ts index 5ccca849e9b36f..764e39fdec1526 100644 --- a/x-pack/plugins/observability_ai_assistant/server/service/util/stream_into_observable.ts +++ b/x-pack/plugins/observability_ai_assistant/server/service/util/stream_into_observable.ts @@ -10,6 +10,7 @@ import type { Readable } from 'stream'; export function streamIntoObservable(readable: Readable): Observable { let lineBuffer = ''; + return from(readable).pipe( map((chunk: Buffer) => chunk.toString('utf-8')), map((part) => {