Skip to content

Commit

Permalink
feat: integrate regenerate API (#58)
Browse files Browse the repository at this point in the history
* feat: integrate regenerate API

Signed-off-by: SuZhou-Joe <suzhou@amazon.com>

* feat: optimize code

Signed-off-by: SuZhou-Joe <suzhou@amazon.com>

* fix: unit test error

Signed-off-by: SuZhou-Joe <suzhou@amazon.com>

* feat: optimize code

Signed-off-by: SuZhou-Joe <suzhou@amazon.com>

* feat: optimize code

Signed-off-by: SuZhou-Joe <suzhou@amazon.com>

---------

Signed-off-by: SuZhou-Joe <suzhou@amazon.com>
  • Loading branch information
SuZhou-Joe committed Dec 12, 2023
1 parent e73dcaa commit 11e5779
Show file tree
Hide file tree
Showing 9 changed files with 278 additions and 65 deletions.
8 changes: 6 additions & 2 deletions public/hooks/use_chat_actions.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -162,15 +162,19 @@ export const useChatActions = (): AssistantActions => {
}
};

const regenerate = async () => {
const regenerate = async (interactionId: string) => {
if (chatContext.sessionId) {
const abortController = new AbortController();
abortControllerRef = abortController;
chatStateDispatch({ type: 'regenerate' });

try {
const response = await core.services.http.put(`${ASSISTANT_API.REGENERATE}`, {
body: JSON.stringify({ sessionId: chatContext.sessionId }),
body: JSON.stringify({
sessionId: chatContext.sessionId,
rootAgentId: chatContext.rootAgentId,
interactionId,
}),
});

if (abortController.signal.aborted) {
Expand Down
7 changes: 7 additions & 0 deletions public/tabs/chat/messages/message_bubble.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,13 @@ describe('<MessageBubble />', () => {
contentType: 'markdown',
content: 'here are the indices in your cluster: .alert',
}}
interaction={{
input: 'foo',
response: 'bar',
conversation_id: 'foo',
interaction_id: 'bar',
create_time: new Date().toLocaleString(),
}}
/>
);
expect(screen.queryAllByTitle('regenerate message')).toHaveLength(1);
Expand Down
8 changes: 4 additions & 4 deletions public/tabs/chat/messages/message_bubble.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ type MessageBubbleProps = {
showActionBar: boolean;
showRegenerate?: boolean;
shouldActionBarVisibleOnHover?: boolean;
onRegenerate?: () => void;
onRegenerate?: (interactionId: string) => void;
} & (
| {
message: IMessage;
Expand Down Expand Up @@ -192,17 +192,17 @@ export const MessageBubble: React.FC<MessageBubbleProps> = React.memo((props) =>
</EuiCopy>
</EuiFlexItem>
)}
{props.showRegenerate && (
{props.showRegenerate && props.interaction?.interaction_id ? (
<EuiFlexItem grow={false}>
<EuiButtonIcon
aria-label="regenerate message"
onClick={props.onRegenerate}
onClick={() => props.onRegenerate?.(props.interaction?.interaction_id || '')}
title="regenerate message"
color="text"
iconType="refresh"
/>
</EuiFlexItem>
)}
) : null}
{showFeedback && (
// After feedback, only corresponding thumb icon will be kept and disabled.
<>
Expand Down
2 changes: 1 addition & 1 deletion public/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ export interface AssistantActions {
openChatUI: (sessionId?: string) => void;
executeAction: (suggestedAction: ISuggestedAction, message: IMessage) => void;
abortAction: (sessionId?: string) => void;
regenerate: () => void;
regenerate: (interactionId: string) => void;
}

export interface AppPluginStartDependencies {
Expand Down
45 changes: 19 additions & 26 deletions server/routes/chat_routes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import {
} from '../../../../src/core/server';
import { ASSISTANT_API } from '../../common/constants/llm';
import { OllyChatService } from '../services/chat/olly_chat_service';
import { IMessage, IInput } from '../../common/types/chat_saved_object_attributes';
import { AgentFrameworkStorageService } from '../services/storage/agent_framework_storage_service';
import { RoutesOptions } from '../types';
import { ChatService } from '../services/chat/chat_service';
Expand Down Expand Up @@ -64,6 +63,7 @@ const regenerateRoute = {
body: schema.object({
sessionId: schema.string(),
rootAgentId: schema.string(),
interactionId: schema.string(),
}),
},
};
Expand Down Expand Up @@ -314,42 +314,35 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions)
request,
response
): Promise<IOpenSearchDashboardsResponse<HttpResponsePayload | ResponseError>> => {
const { sessionId, rootAgentId } = request.body;
const { sessionId, rootAgentId, interactionId } = request.body;
const storageService = createStorageService(context);
let messages: IMessage[] = [];
const chatService = createChatService();

let outputs: Awaited<ReturnType<ChatService['regenerate']>> | undefined;

/**
* Get final answer from Agent framework
*/
try {
const session = await storageService.getSession(sessionId);
messages.push(...session.messages);
outputs = await chatService.regenerate({ sessionId, rootAgentId, interactionId }, context);
} catch (error) {
return response.custom({ statusCode: error.statusCode || 500, body: error.message });
context.assistant_plugin.logger.error(error);
}

const lastInputIndex = messages.findLastIndex((msg) => msg.type === 'input');
// Find last input message
const input = messages[lastInputIndex] as IInput;
// Take the messages before last input message as memory as regenerate will exclude the last outputs
messages = messages.slice(0, lastInputIndex);

/**
* Retrieve latest interactions from memory
*/
try {
const outputs = await chatService.requestLLM(
{ messages, input, sessionId, rootAgentId },
context
);
const title = input.content.substring(0, 50);
const saveMessagesResponse = await storageService.saveMessages(
title,
sessionId,
[...messages, input, ...outputs.messages].filter(
(message) => message.content !== 'AbortError'
)
);
const conversation = await storageService.getSession(sessionId);

return response.ok({
body: { ...saveMessagesResponse, title },
body: {
...conversation,
sessionId,
},
});
} catch (error) {
context.assistant_plugin.logger.warn(error);
context.assistant_plugin.logger.error(error);
return response.custom({ statusCode: error.statusCode || 500, body: error.message });
}
}
Expand Down
11 changes: 9 additions & 2 deletions server/services/chat/chat_service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,15 @@ import { LLMRequestSchema } from '../../routes/chat_routes';
export interface ChatService {
requestLLM(
payload: { messages: IMessage[]; input: IInput; sessionId?: string },
context: RequestHandlerContext,
request: OpenSearchDashboardsRequest<unknown, unknown, LLMRequestSchema, 'post'>
context: RequestHandlerContext
): Promise<{
messages: IMessage[];
memoryId: string;
}>;

regenerate(
payload: { sessionId: string; interactionId: string; rootAgentId: string },
context: RequestHandlerContext
): Promise<{
messages: IMessage[];
memoryId: string;
Expand Down
176 changes: 176 additions & 0 deletions server/services/chat/olly_chat_service.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

import { OllyChatService } from './olly_chat_service';
import { CoreRouteHandlerContext } from '../../../../../src/core/server/core_route_handler_context';
import { coreMock, httpServerMock } from '../../../../../src/core/server/mocks';
import { loggerMock } from '../../../../../src/core/server/logging/logger.mock';

describe('OllyChatService', () => {
const ollyChatService = new OllyChatService();
const coreContext = new CoreRouteHandlerContext(
coreMock.createInternalStart(),
httpServerMock.createOpenSearchDashboardsRequest()
);
const mockedTransport = coreContext.opensearch.client.asCurrentUser.transport
.request as jest.Mock;
const contextMock = {
core: coreContext,
assistant_plugin: {
logger: loggerMock.create(),
},
};
beforeEach(() => {
mockedTransport.mockClear();
});
it('requestLLM should invoke client call with correct params', async () => {
mockedTransport.mockImplementationOnce(() => {
return {
body: {
inference_results: [
{
output: [
{
name: 'memory_id',
result: 'foo',
},
],
},
],
},
};
});
const result = await ollyChatService.requestLLM(
{
messages: [],
input: {
type: 'input',
contentType: 'text',
content: 'content',
},
sessionId: '',
rootAgentId: 'rootAgentId',
},
contextMock
);
expect(mockedTransport.mock.calls).toMatchInlineSnapshot(`
Array [
Array [
Object {
"body": Object {
"parameters": Object {
"question": "content",
"verbose": true,
},
},
"method": "POST",
"path": "/_plugins/_ml/agents/rootAgentId/_execute",
},
Object {
"maxRetries": 0,
"requestTimeout": 300000,
},
],
]
`);
expect(result).toMatchInlineSnapshot(`
Object {
"memoryId": "foo",
"messages": Array [],
}
`);
});

it('requestLLM should throw error when transport.request throws error', async () => {
mockedTransport.mockImplementationOnce(() => {
throw new Error('error');
});
expect(
ollyChatService.requestLLM(
{
messages: [],
input: {
type: 'input',
contentType: 'text',
content: 'content',
},
sessionId: '',
rootAgentId: 'rootAgentId',
},
contextMock
)
).rejects.toMatchInlineSnapshot(`[Error: error]`);
});

it('regenerate should invoke client call with correct params', async () => {
mockedTransport.mockImplementationOnce(() => {
return {
body: {
inference_results: [
{
output: [
{
name: 'memory_id',
result: 'foo',
},
],
},
],
},
};
});
const result = await ollyChatService.regenerate(
{
sessionId: 'sessionId',
rootAgentId: 'rootAgentId',
interactionId: 'interactionId',
},
contextMock
);
expect(mockedTransport.mock.calls).toMatchInlineSnapshot(`
Array [
Array [
Object {
"body": Object {
"parameters": Object {
"memory_id": "sessionId",
"regenerate_interaction_id": "interactionId",
"verbose": true,
},
},
"method": "POST",
"path": "/_plugins/_ml/agents/rootAgentId/_execute",
},
Object {
"maxRetries": 0,
"requestTimeout": 300000,
},
],
]
`);
expect(result).toMatchInlineSnapshot(`
Object {
"memoryId": "foo",
"messages": Array [],
}
`);
});

it('regenerate should throw error when transport.request throws error', async () => {
mockedTransport.mockImplementationOnce(() => {
throw new Error('error');
});
expect(
ollyChatService.regenerate(
{
sessionId: 'sessionId',
rootAgentId: 'rootAgentId',
interactionId: 'interactionId',
},
contextMock
)
).rejects.toMatchInlineSnapshot(`[Error: error]`);
});
});
Loading

0 comments on commit 11e5779

Please sign in to comment.