diff --git a/package.json b/package.json index e97faff9e78..5c0326b27d2 100644 --- a/package.json +++ b/package.json @@ -151,7 +151,7 @@ "@types/ua-parser-js": "^0.7", "@types/uuid": "^9", "@umijs/lint": "^4", - "@vitest/coverage-v8": "0.34.6", + "@vitest/coverage-v8": "^1", "commitlint": "^18", "consola": "^3", "dpdm": "^3", @@ -175,7 +175,7 @@ "typescript": "^5", "unified": "^11", "unist-util-visit": "^5", - "vitest": "0.34.6", + "vitest": "^1", "vitest-canvas-mock": "^0.3.3" }, "publishConfig": { diff --git a/src/database/schemas/message.ts b/src/database/schemas/message.ts index 4b34bc11fd1..5e5c67474c8 100644 --- a/src/database/schemas/message.ts +++ b/src/database/schemas/message.ts @@ -11,7 +11,7 @@ const PluginSchema = z.object({ identifier: z.string(), arguments: z.string(), apiName: z.string(), - type: z.enum(['default', 'standalone', 'builtin']).default('default'), + type: z.enum(['default', 'markdown', 'standalone', 'builtin']).default('default'), }); export const DB_MessageSchema = z.object({ diff --git a/src/features/Conversation/ChatList/Plugins/Render/BuiltinType/index.tsx b/src/features/Conversation/ChatList/Plugins/Render/BuiltinType/index.tsx index 586adfe780a..8280930a85b 100644 --- a/src/features/Conversation/ChatList/Plugins/Render/BuiltinType/index.tsx +++ b/src/features/Conversation/ChatList/Plugins/Render/BuiltinType/index.tsx @@ -1,5 +1,4 @@ import { memo } from 'react'; -import { Flexbox } from 'react-layout-kit'; import { BuiltinToolsRenders } from '@/tools/renders'; @@ -17,13 +16,7 @@ const BuiltinType = memo(({ content, id, identifier, loading } const { isJSON, data } = useParseContent(content); if (!isJSON) { - return ( - loading && ( - - - - ) - ); + return loading && ; } const Render = BuiltinToolsRenders[identifier || '']; diff --git a/src/features/Conversation/ChatList/Plugins/Render/DefaultType/index.tsx b/src/features/Conversation/ChatList/Plugins/Render/DefaultType/index.tsx index 88b8492b2fa..baf83444324 100644 --- a/src/features/Conversation/ChatList/Plugins/Render/DefaultType/index.tsx +++ b/src/features/Conversation/ChatList/Plugins/Render/DefaultType/index.tsx @@ -1,7 +1,6 @@ import { Skeleton } from 'antd'; import dynamic from 'next/dynamic'; import { Suspense, memo } from 'react'; -import { Flexbox } from 'react-layout-kit'; import { useToolStore } from '@/store/tool'; import { pluginSelectors } from '@/store/tool/selectors'; @@ -24,13 +23,7 @@ const PluginDefaultType = memo(({ content, name, loading const { isJSON, data } = useParseContent(content); if (!isJSON) { - return ( - loading && ( - - - - ) - ); + return loading && ; } if (!manifest?.ui) return; diff --git a/src/features/Conversation/ChatList/Plugins/Render/MarkdownType/index.tsx b/src/features/Conversation/ChatList/Plugins/Render/MarkdownType/index.tsx new file mode 100644 index 00000000000..6fe73ee53a8 --- /dev/null +++ b/src/features/Conversation/ChatList/Plugins/Render/MarkdownType/index.tsx @@ -0,0 +1,17 @@ +import { Markdown } from '@lobehub/ui'; +import { memo } from 'react'; + +import Loading from '../Loading'; + +export interface PluginMarkdownTypeProps { + content: string; + loading?: boolean; +} + +const PluginMarkdownType = memo(({ content, loading }) => { + if (loading) return ; + + return {content}; +}); + +export default PluginMarkdownType; diff --git a/src/features/Conversation/ChatList/Plugins/Render/StandaloneType/Iframe.tsx b/src/features/Conversation/ChatList/Plugins/Render/StandaloneType/Iframe.tsx index e7b0c9d50f9..ca4d9dd91c2 100644 --- a/src/features/Conversation/ChatList/Plugins/Render/StandaloneType/Iframe.tsx +++ b/src/features/Conversation/ChatList/Plugins/Render/StandaloneType/Iframe.tsx @@ -9,10 +9,12 @@ import { pluginSelectors } from '@/store/tool/selectors'; import { useOnPluginReadyForInteraction } from '../utils/iframeOnReady'; import { + useOnPluginCreateAssistantMessage, useOnPluginFetchMessage, useOnPluginFetchPluginSettings, useOnPluginFetchPluginState, useOnPluginFillContent, + useOnPluginTriggerAIMessage, } from '../utils/listenToPlugin'; import { useOnPluginSettingsUpdate } from '../utils/pluginSettings'; import { useOnPluginStateUpdate } from '../utils/pluginState'; @@ -118,6 +120,21 @@ const IFrameRender = memo(({ url, id, payload, width = 600, h updatePluginSettings(payload?.identifier, value); }); + // when plugin want to trigger AI message + const triggerAIMessage = useChatStore((s) => s.triggerAIMessage); + useOnPluginTriggerAIMessage((messageId) => { + // we need to know which message to trigger + if (messageId !== id) return; + + triggerAIMessage(id); + }); + + // when plugin want to create an assistant message + const createAssistantMessage = useChatStore((s) => s.createAssistantMessageByPlugin); + useOnPluginCreateAssistantMessage((content) => { + createAssistantMessage(content, id); + }); + return ( <> {loading && } diff --git a/src/features/Conversation/ChatList/Plugins/Render/index.tsx b/src/features/Conversation/ChatList/Plugins/Render/index.tsx index cbc286a7da2..6bd8ebdade5 100644 --- a/src/features/Conversation/ChatList/Plugins/Render/index.tsx +++ b/src/features/Conversation/ChatList/Plugins/Render/index.tsx @@ -5,6 +5,7 @@ import { LobeToolRenderType } from '@/types/tool'; import BuiltinType from '././BuiltinType'; import DefaultType from './DefaultType'; +import Markdown from './MarkdownType'; import Standalone from './StandaloneType'; export interface PluginRenderProps { @@ -27,6 +28,10 @@ const PluginRender = memo( return ; } + case 'markdown': { + return ; + } + default: { return ; } diff --git a/src/features/Conversation/ChatList/Plugins/Render/utils/listenToPlugin.test.ts b/src/features/Conversation/ChatList/Plugins/Render/utils/listenToPlugin.test.ts index 7fcdede6d8e..adb158f827f 100644 --- a/src/features/Conversation/ChatList/Plugins/Render/utils/listenToPlugin.test.ts +++ b/src/features/Conversation/ChatList/Plugins/Render/utils/listenToPlugin.test.ts @@ -3,10 +3,12 @@ import { renderHook } from '@testing-library/react'; import { afterEach, describe, expect, it, vi } from 'vitest'; import { + useOnPluginCreateAssistantMessage, useOnPluginFetchMessage, useOnPluginFetchPluginSettings, useOnPluginFetchPluginState, useOnPluginFillContent, + useOnPluginTriggerAIMessage, } from './listenToPlugin'; afterEach(() => { @@ -102,3 +104,61 @@ describe('useOnPluginFetchPluginSettings', () => { expect(mockOnRequest).toHaveBeenCalled(); }); }); + +describe('useOnPluginTriggerAIMessage', () => { + it('calls callback with id when a triggerAIMessage is received', () => { + const mockCallback = vi.fn(); + renderHook(() => useOnPluginTriggerAIMessage(mockCallback)); + + const testId = 'testId'; + const event = new MessageEvent('message', { + data: { type: PluginChannel.triggerAIMessage, id: testId }, + }); + + window.dispatchEvent(event); + + expect(mockCallback).toHaveBeenCalledWith(testId); + }); + + it('does not call callback for other message types', () => { + const mockCallback = vi.fn(); + renderHook(() => useOnPluginTriggerAIMessage(mockCallback)); + + const event = new MessageEvent('message', { + data: { type: 'otherMessageType', id: 'testId' }, + }); + + window.dispatchEvent(event); + + expect(mockCallback).not.toHaveBeenCalled(); + }); +}); + +describe('useOnPluginCreateAssistantMessage', () => { + it('calls callback with content when a createAssistantMessage is received', () => { + const mockCallback = vi.fn(); + renderHook(() => useOnPluginCreateAssistantMessage(mockCallback)); + + const testContent = 'testContent'; + const event = new MessageEvent('message', { + data: { type: PluginChannel.createAssistantMessage, content: testContent }, + }); + + window.dispatchEvent(event); + + expect(mockCallback).toHaveBeenCalledWith(testContent); + }); + + it('does not call callback for other message types', () => { + const mockCallback = vi.fn(); + renderHook(() => useOnPluginCreateAssistantMessage(mockCallback)); + + const event = new MessageEvent('message', { + data: { type: 'otherMessageType', content: 'testContent' }, + }); + + window.dispatchEvent(event); + + expect(mockCallback).not.toHaveBeenCalled(); + }); +}); diff --git a/src/features/Conversation/ChatList/Plugins/Render/utils/listenToPlugin.ts b/src/features/Conversation/ChatList/Plugins/Render/utils/listenToPlugin.ts index 04782fcd7ce..ff905cd8d29 100644 --- a/src/features/Conversation/ChatList/Plugins/Render/utils/listenToPlugin.ts +++ b/src/features/Conversation/ChatList/Plugins/Render/utils/listenToPlugin.ts @@ -63,3 +63,33 @@ export const useOnPluginFetchPluginSettings = (onRequest: () => void) => { }; }, []); }; + +export const useOnPluginTriggerAIMessage = (callback: (id: string) => void) => { + useEffect(() => { + const fn = (e: MessageEvent) => { + if (e.data.type === PluginChannel.triggerAIMessage) { + callback(e.data.id); + } + }; + + window.addEventListener('message', fn); + return () => { + window.removeEventListener('message', fn); + }; + }, []); +}; + +export const useOnPluginCreateAssistantMessage = (callback: (content: string) => void) => { + useEffect(() => { + const fn = (e: MessageEvent) => { + if (e.data.type === PluginChannel.createAssistantMessage) { + callback(e.data.content); + } + }; + + window.addEventListener('message', fn); + return () => { + window.removeEventListener('message', fn); + }; + }, []); +}; diff --git a/src/services/__tests__/__snapshots__/plugin.test.ts.snap b/src/services/__tests__/__snapshots__/plugin.test.ts.snap index 73db3e32a11..7dfe6d4f04d 100644 --- a/src/services/__tests__/__snapshots__/plugin.test.ts.snap +++ b/src/services/__tests__/__snapshots__/plugin.test.ts.snap @@ -33,17 +33,17 @@ General guidelines: - Inform users if information is not from Wolfram endpoints. - Display image URLs with Markdown syntax: ![URL] - ALWAYS use this exponent notation: \`6*10^14\`, NEVER \`6e14\`. -- ALWAYS use {\\"input\\": query} structure for queries to Wolfram endpoints; \`query\` must ONLY be a single-line string. -- ALWAYS use proper Markdown formatting for all math, scientific, and chemical formulas, symbols, etc.: '$$\\\\n[expression]\\\\n$$' for standalone cases and '\\\\( [expression] \\\\)' when inline. +- ALWAYS use {"input": query} structure for queries to Wolfram endpoints; \`query\` must ONLY be a single-line string. +- ALWAYS use proper Markdown formatting for all math, scientific, and chemical formulas, symbols, etc.: '$$\\n[expression]\\n$$' for standalone cases and '\\( [expression] \\)' when inline. - Format inline Wolfram Language code with Markdown code formatting. - Never mention your knowledge cutoff date; Wolfram may return more recent data. getWolframAlphaResults guidelines: - Understands natural language queries about entities in chemistry, physics, geography, history, art, astronomy, and more. - Performs mathematical calculations, date and unit conversions, formula solving, etc. -- Convert inputs to simplified keyword queries whenever possible (e.g. convert \\"how many people live in France\\" to \\"France population\\"). +- Convert inputs to simplified keyword queries whenever possible (e.g. convert "how many people live in France" to "France population"). - Use ONLY single-letter variable names, with or without integer subscript (e.g., n, n1, n_1). - Use named physical constants (e.g., 'speed of light') without numerical substitution. -- Include a space between compound units (e.g., \\"Ω m\\" for \\"ohm*meter\\"). +- Include a space between compound units (e.g., "Ω m" for "ohm*meter"). - To solve for a variable in an equation with units, consider solving a corresponding equation without units; exclude counting units (e.g., books), include genuine units (e.g., kg). - If data for multiple properties is needed, make separate calls for each property. - If a Wolfram Alpha result is not relevant to the query: @@ -55,22 +55,22 @@ getWolframCloudResults guidelines: - Accepts only syntactically correct Wolfram Language code. - Performs complex calculations, data analysis, plotting, data import, and information retrieval. - Before writing code that uses Entity, EntityProperty, EntityClass, etc. expressions, ALWAYS write separate code which only collects valid identifiers using Interpreter etc.; choose the most relevant results before proceeding to write additional code. Examples: - -- Find the EntityType that represents countries: \`Interpreter[\\"EntityType\\",AmbiguityFunction->All][\\"countries\\"]\`. - -- Find the Entity for the Empire State Building: \`Interpreter[\\"Building\\",AmbiguityFunction->All][\\"empire state\\"]\`. - -- EntityClasses: Find the \\"Movie\\" entity class for Star Trek movies: \`Interpreter[\\"MovieClass\\",AmbiguityFunction->All][\\"star trek\\"]\`. - -- Find EntityProperties associated with \\"weight\\" of \\"Element\\" entities: \`Interpreter[Restricted[\\"EntityProperty\\", \\"Element\\"],AmbiguityFunction->All][\\"weight\\"]\`. - -- If all else fails, try to find any valid Wolfram Language representation of a given input: \`SemanticInterpretation[\\"skyscrapers\\",_,Hold,AmbiguityFunction->All]\`. - -- Prefer direct use of entities of a given type to their corresponding typeData function (e.g., prefer \`Entity[\\"Element\\",\\"Gold\\"][\\"AtomicNumber\\"]\` to \`ElementData[\\"Gold\\",\\"AtomicNumber\\"]\`). + -- Find the EntityType that represents countries: \`Interpreter["EntityType",AmbiguityFunction->All]["countries"]\`. + -- Find the Entity for the Empire State Building: \`Interpreter["Building",AmbiguityFunction->All]["empire state"]\`. + -- EntityClasses: Find the "Movie" entity class for Star Trek movies: \`Interpreter["MovieClass",AmbiguityFunction->All]["star trek"]\`. + -- Find EntityProperties associated with "weight" of "Element" entities: \`Interpreter[Restricted["EntityProperty", "Element"],AmbiguityFunction->All]["weight"]\`. + -- If all else fails, try to find any valid Wolfram Language representation of a given input: \`SemanticInterpretation["skyscrapers",_,Hold,AmbiguityFunction->All]\`. + -- Prefer direct use of entities of a given type to their corresponding typeData function (e.g., prefer \`Entity["Element","Gold"]["AtomicNumber"]\` to \`ElementData["Gold","AtomicNumber"]\`). - When composing code: -- Use batching techniques to retrieve data for multiple entities in a single call, if applicable. -- Use Association to organize and manipulate data when appropriate. -- Optimize code for performance and minimize the number of calls to external sources (e.g., the Wolfram Knowledgebase) -- Use only camel case for variable names (e.g., variableName). - -- Use ONLY double quotes around all strings, including plot labels, etc. (e.g., \`PlotLegends -> {\\"sin(x)\\", \\"cos(x)\\", \\"tan(x)\\"}\`). + -- Use ONLY double quotes around all strings, including plot labels, etc. (e.g., \`PlotLegends -> {"sin(x)", "cos(x)", "tan(x)"}\`). -- Avoid use of QuantityMagnitude. - -- If unevaluated Wolfram Language symbols appear in API results, use \`EntityValue[Entity[\\"WolframLanguageSymbol\\",symbol],{\\"PlaintextUsage\\",\\"Options\\"}]\` to validate or retrieve usage information for relevant symbols; \`symbol\` may be a list of symbols. + -- If unevaluated Wolfram Language symbols appear in API results, use \`EntityValue[Entity["WolframLanguageSymbol",symbol],{"PlaintextUsage","Options"}]\` to validate or retrieve usage information for relevant symbols; \`symbol\` may be a list of symbols. -- Apply Evaluate to complex expressions like integrals before plotting (e.g., \`Plot[Evaluate[Integrate[...]]]\`). -- Remove all comments and formatting from code passed to the \\"input\\" parameter; for example: instead of \`square[x_] := Module[{result},\\\\n result = x^2 (* Calculate the square *)\\\\n]\`, send \`square[x_]:=Module[{result},result=x^2]\`. +- Remove all comments and formatting from code passed to the "input" parameter; for example: instead of \`square[x_] := Module[{result},\\n result = x^2 (* Calculate the square *)\\n]\`, send \`square[x_]:=Module[{result},result=x^2]\`. - In ALL responses that involve code, write ALL code in Wolfram Language; create Wolfram Language functions even if an implementation is already well known in another language. ", "type": "default", diff --git a/src/store/chat/slices/message/action.ts b/src/store/chat/slices/message/action.ts index db4b696d23c..123b7a0c66a 100644 --- a/src/store/chat/slices/message/action.ts +++ b/src/store/chat/slices/message/action.ts @@ -8,7 +8,6 @@ import { StateCreator } from 'zustand/vanilla'; import { GPT4_VISION_MODEL_DEFAULT_MAX_TOKENS } from '@/const/llm'; import { LOADING_FLAT, isFunctionMessageAtStart, testFunctionMessageAtEnd } from '@/const/message'; import { CreateMessageParams } from '@/database/models/message'; -import { DB_Message } from '@/database/schemas/message'; import { chatService } from '@/services/chat'; import { messageService } from '@/services/message'; import { topicService } from '@/services/topic'; @@ -238,7 +237,7 @@ export const chatMessage: StateCreator< const { model } = getAgentConfig(); // 1. Add an empty message to place the AI response - const assistantMessage: DB_Message = { + const assistantMessage: CreateMessageParams = { role: 'assistant', content: LOADING_FLAT, fromModel: model, @@ -383,7 +382,7 @@ export const chatMessage: StateCreator< toggleChatLoading(false, undefined, n('generateMessage(end)') as string); // also exist message like this: - // 请稍等,我帮您查询一下。{"function_call": {"name": "plugin-identifier____recommendClothes____standalone", "arguments": "{\n "mood": "",\n "gender": "man"\n}"}} + // 请稍等,我帮您查询一下。{"tool_calls": {"name": "plugin-identifier____recommendClothes____standalone", "arguments": "{\n "mood": "",\n "gender": "man"\n}"}} if (!isFunctionCall) { const { content, valid } = testFunctionMessageAtEnd(output); diff --git a/src/store/chat/slices/tool/action.test.ts b/src/store/chat/slices/tool/action.test.ts index 41cbf08251b..ce89561518a 100644 --- a/src/store/chat/slices/tool/action.test.ts +++ b/src/store/chat/slices/tool/action.test.ts @@ -9,6 +9,7 @@ import { chatSelectors } from '@/store/chat/selectors'; import { useChatStore } from '@/store/chat/store'; import { useToolStore } from '@/store/tool'; import { pluginSelectors } from '@/store/tool/selectors'; +import { ChatPluginPayload } from '@/types/message'; import { LobeTool } from '@/types/tool'; // Mock messageService 和 chatSelectors @@ -18,6 +19,7 @@ vi.mock('@/services/message', () => ({ updateMessage: vi.fn(), updateMessageError: vi.fn(), updateMessagePluginState: vi.fn(), + create: vi.fn(), }, })); vi.mock('@/services/chat', () => ({ @@ -32,6 +34,10 @@ vi.mock('@/store/chat/selectors', () => ({ getMessageById: vi.fn(), }, })); +beforeEach(() => { + // 在每个测试之前重置模拟函数 + vi.resetAllMocks(); +}); describe('ChatPluginAction', () => { describe('fillPluginMessageContent', () => { @@ -302,4 +308,235 @@ describe('ChatPluginAction', () => { expect(initialState.refreshMessages).toHaveBeenCalled(); }); }); + + describe('createAssistantMessageByPlugin', () => { + it('should create an assistant message and refresh messages', async () => { + // 模拟 messageService.create 方法的实现 + (messageService.create as Mock).mockResolvedValue({}); + + // 设置初始状态并模拟 refreshMessages 方法 + const initialState = { + refreshMessages: vi.fn(), + activeId: 'session-id', + activeTopicId: 'topic-id', + }; + useChatStore.setState(initialState); + + const { result } = renderHook(() => useChatStore()); + + const content = 'Test content'; + const parentId = 'parent-message-id'; + + await act(async () => { + await result.current.createAssistantMessageByPlugin(content, parentId); + }); + + // 验证 messageService.create 是否被带有正确参数调用 + expect(messageService.create).toHaveBeenCalledWith({ + content, + parentId, + role: 'assistant', + sessionId: initialState.activeId, + topicId: initialState.activeTopicId, + }); + + // 验证 refreshMessages 是否被调用 + expect(result.current.refreshMessages).toHaveBeenCalled(); + }); + + it('should handle errors when message creation fails', async () => { + // 模拟 messageService.create 方法,使其抛出错误 + const errorMessage = 'Failed to create message'; + (messageService.create as Mock).mockRejectedValue(new Error(errorMessage)); + + // 设置初始状态并模拟 refreshMessages 方法 + const initialState = { + refreshMessages: vi.fn(), + activeId: 'session-id', + activeTopicId: 'topic-id', + }; + useChatStore.setState(initialState); + + const { result } = renderHook(() => useChatStore()); + + const content = 'Test content'; + const parentId = 'parent-message-id'; + + await act(async () => { + await expect( + result.current.createAssistantMessageByPlugin(content, parentId), + ).rejects.toThrow(errorMessage); + }); + + // 验证 messageService.create 是否被带有正确参数调用 + expect(messageService.create).toHaveBeenCalledWith({ + content, + parentId, + role: 'assistant', + sessionId: initialState.activeId, + topicId: initialState.activeTopicId, + }); + + // 验证 refreshMessages 是否没有被调用 + expect(result.current.refreshMessages).not.toHaveBeenCalled(); + }); + }); + + describe('invokeBuiltinTool', () => { + it('should invoke a builtin tool and update message content ,then run text2image', async () => { + const payload = { + apiName: 'text2image', + arguments: JSON.stringify({ key: 'value' }), + } as ChatPluginPayload; + + const messageId = 'message-id'; + const toolResponse = JSON.stringify({ abc: 'data' }); + + useToolStore.setState({ + invokeBuiltinTool: vi.fn().mockResolvedValue(toolResponse), + }); + + useChatStore.setState({ + toggleChatLoading: vi.fn(), + updateMessageContent: vi.fn(), + text2image: vi.fn(), + }); + + const { result } = renderHook(() => useChatStore()); + + await act(async () => { + await result.current.invokeBuiltinTool(messageId, payload); + }); + + // Verify that the builtin tool was invoked with the correct arguments + expect(useToolStore.getState().invokeBuiltinTool).toHaveBeenCalledWith( + payload.apiName, + JSON.parse(payload.arguments), + ); + + // Verify that the message content was updated with the tool response + expect(result.current.updateMessageContent).toHaveBeenCalledWith(messageId, toolResponse); + + // Verify that loading was toggled correctly + expect(result.current.toggleChatLoading).toHaveBeenCalledWith( + true, + messageId, + expect.any(String), + ); + expect(result.current.toggleChatLoading).toHaveBeenCalledWith(false); + expect(useChatStore.getState().text2image).toHaveBeenCalled(); + }); + + it('should invoke a builtin tool and update message content', async () => { + const payload = { + apiName: 'text2image', + arguments: JSON.stringify({ key: 'value' }), + } as ChatPluginPayload; + + const messageId = 'message-id'; + const toolResponse = 'Builtin tool response'; + + act(() => { + useToolStore.setState({ + invokeBuiltinTool: vi.fn().mockResolvedValue(toolResponse), + text2image: vi.fn(), + }); + + useChatStore.setState({ + toggleChatLoading: vi.fn(), + text2image: vi.fn(), + updateMessageContent: vi.fn(), + }); + }); + const { result } = renderHook(() => useChatStore()); + + await act(async () => { + await result.current.invokeBuiltinTool(messageId, payload); + }); + + // Verify that the builtin tool was invoked with the correct arguments + expect(useToolStore.getState().invokeBuiltinTool).toHaveBeenCalledWith( + payload.apiName, + JSON.parse(payload.arguments), + ); + + // Verify that the message content was updated with the tool response + expect(result.current.updateMessageContent).toHaveBeenCalledWith(messageId, toolResponse); + + // Verify that loading was toggled correctly + expect(result.current.toggleChatLoading).toHaveBeenCalledWith( + true, + messageId, + expect.any(String), + ); + expect(result.current.toggleChatLoading).toHaveBeenCalledWith(false); + expect(useChatStore.getState().text2image).not.toHaveBeenCalled(); + }); + it('should handle errors when invoking a builtin tool fails', async () => { + const payload = { + apiName: 'builtinApi', + arguments: JSON.stringify({ key: 'value' }), + } as ChatPluginPayload; + + const messageId = 'message-id'; + const error = new Error('Builtin tool failed'); + + useToolStore.setState({ + invokeBuiltinTool: vi.fn().mockRejectedValue(error), + }); + + useChatStore.setState({ + toggleChatLoading: vi.fn(), + updateMessageContent: vi.fn(), + text2image: vi.fn(), + refreshMessages: vi.fn(), + }); + + const { result } = renderHook(() => useChatStore()); + + await act(async () => { + await result.current.invokeBuiltinTool(messageId, payload); + }); + + // Verify that loading was toggled correctly + expect(result.current.toggleChatLoading).toHaveBeenCalledWith( + true, + messageId, + expect.any(String), + ); + expect(result.current.toggleChatLoading).toHaveBeenCalledWith(false); + + // Verify that the message content was not updated + expect(result.current.updateMessageContent).not.toHaveBeenCalled(); + + // Verify that messages were not refreshed + expect(result.current.refreshMessages).not.toHaveBeenCalled(); + expect(useChatStore.getState().text2image).not.toHaveBeenCalled(); + }); + }); + + describe('invokeMarkdownTypePlugin', () => { + it('should invoke a markdown type plugin', async () => { + const payload = { + apiName: 'markdownApi', + identifier: 'abc', + type: 'markdown', + arguments: JSON.stringify({ key: 'value' }), + } as ChatPluginPayload; + const messageId = 'message-id'; + + useChatStore.setState({ + runPluginApi: vi.fn().mockResolvedValue('Markdown response'), + }); + + const { result } = renderHook(() => useChatStore()); + + await act(async () => { + await result.current.invokeMarkdownTypePlugin(messageId, payload); + }); + + // Verify that the markdown type plugin was invoked + expect(result.current.runPluginApi).toHaveBeenCalledWith(messageId, payload); + }); + }); }); diff --git a/src/store/chat/slices/tool/action.ts b/src/store/chat/slices/tool/action.ts index 3cf1aa931a6..81baa7afe84 100644 --- a/src/store/chat/slices/tool/action.ts +++ b/src/store/chat/slices/tool/action.ts @@ -2,6 +2,7 @@ import { Md5 } from 'ts-md5'; import { StateCreator } from 'zustand/vanilla'; import { PLUGIN_SCHEMA_API_MD5_PREFIX, PLUGIN_SCHEMA_SEPARATOR } from '@/const/plugin'; +import { CreateMessageParams } from '@/database/models/message'; import { chatService } from '@/services/chat'; import { messageService } from '@/services/message'; import { ChatStore } from '@/store/chat/store'; @@ -16,9 +17,13 @@ import { chatSelectors } from '../../selectors'; const n = setNamespace('plugin'); export interface ChatPluginAction { + createAssistantMessageByPlugin: (content: string, parentId: string) => Promise; fillPluginMessageContent: (id: string, content: string) => Promise; invokeBuiltinTool: (id: string, payload: ChatPluginPayload) => Promise; invokeDefaultTypePlugin: (id: string, payload: any) => Promise; + invokeMarkdownTypePlugin: (id: string, payload: ChatPluginPayload) => Promise; + runPluginApi: (id: string, payload: ChatPluginPayload) => Promise; + triggerAIMessage: (id: string) => Promise; triggerFunctionCall: (id: string) => Promise; updatePluginState: (id: string, key: string, value: any) => Promise; } @@ -29,34 +34,77 @@ export const chatPlugin: StateCreator< [], ChatPluginAction > = (set, get) => ({ + createAssistantMessageByPlugin: async (content, parentId) => { + const newMessage: CreateMessageParams = { + content, + parentId, + role: 'assistant', + sessionId: get().activeId, + topicId: get().activeTopicId, // if there is activeTopicId,then add it to topicId + }; + + await messageService.create(newMessage); + await get().refreshMessages(); + }, + fillPluginMessageContent: async (id, content) => { - const { coreProcessMessage, updateMessageContent } = get(); + const { triggerAIMessage, updateMessageContent } = get(); await updateMessageContent(id, content); - const chats = chatSelectors.currentChats(get()); - await coreProcessMessage(chats, id); + await triggerAIMessage(id); }, + invokeBuiltinTool: async (id, payload) => { const { toggleChatLoading, updateMessageContent } = get(); const params = JSON.parse(payload.arguments); toggleChatLoading(true, id, n('invokeBuiltinTool') as string); - const data = await useToolStore.getState().invokeBuiltinTool(payload.apiName, params); + let data; + try { + data = await useToolStore.getState().invokeBuiltinTool(payload.apiName, params); + } catch (error) { + console.log(error); + } toggleChatLoading(false); - if (data) { - await updateMessageContent(id, data); - } + if (!data) return; + + await updateMessageContent(id, data); // postToolCalling // @ts-ignore const { [payload.apiName]: action } = get(); - if (!action || !data) return; + if (!action) return; + + let content; - await action(id, JSON.parse(data)); + try { + content = JSON.parse(data); + } catch {} + + if (!content) return; + + await action(id, content); }, + invokeDefaultTypePlugin: async (id, payload) => { - const { updateMessageContent, refreshMessages, coreProcessMessage, toggleChatLoading } = get(); + const { runPluginApi, triggerAIMessage } = get(); + + const data = await runPluginApi(id, payload); + + if (!data) return; + + await triggerAIMessage(id); + }, + + invokeMarkdownTypePlugin: async (id, payload) => { + const { runPluginApi } = get(); + + await runPluginApi(id, payload); + }, + + runPluginApi: async (id, payload) => { + const { updateMessageContent, refreshMessages, toggleChatLoading } = get(); let data: string; try { @@ -73,20 +121,32 @@ export const chatPlugin: StateCreator< data = ''; } + toggleChatLoading(false); // 如果报错则结束了 if (!data) return; await updateMessageContent(id, data); + return data; + }, + + triggerAIMessage: async (id) => { + const { coreProcessMessage } = get(); const chats = chatSelectors.currentChats(get()); await coreProcessMessage(chats, id); }, + triggerFunctionCall: async (id) => { const message = chatSelectors.getMessageById(id)(get()); if (!message) return; - const { invokeDefaultTypePlugin, invokeBuiltinTool, refreshMessages } = get(); + const { + invokeDefaultTypePlugin, + invokeMarkdownTypePlugin, + invokeBuiltinTool, + refreshMessages, + } = get(); let payload = { apiName: '', identifier: '' } as ChatPluginPayload; @@ -135,10 +195,16 @@ export const chatPlugin: StateCreator< // TODO: need to auth user's settings break; } + case 'markdown': { + await invokeMarkdownTypePlugin(id, payload); + break; + } + case 'builtin': { await invokeBuiltinTool(id, payload); break; } + default: { await invokeDefaultTypePlugin(id, payload); } diff --git a/src/types/message/tools.ts b/src/types/message/tools.ts index ca87eba7bc5..39a7cb92fb5 100644 --- a/src/types/message/tools.ts +++ b/src/types/message/tools.ts @@ -1,6 +1,8 @@ +import { LobeToolRenderType } from '@/types/tool'; + export interface ChatPluginPayload { apiName: string; arguments: string; identifier: string; - type: 'standalone' | 'default' | 'builtin'; + type: LobeToolRenderType; }