Skip to content

Commit

Permalink
[Search] [Chat Playground] handle when the ActionLLM is not a ChatMod…
Browse files Browse the repository at this point in the history
…el (elastic#183931)

## Summary

Two action based LLMs: `ActionsClientChatOpenAI` and `ActionsClientLlm`.
`ActionsClientChatOpenAI` is based on the ChatModel LLM,
`ActionsClientLlm` is a prompt based model. The callbacks are different
when using a ChatModel vs LLMModel. Token count is done on the
ChatModelStart callback. This meant the token count didn't happen for
LLMModel based actions (Bedrock).

To fix i listen on both callbacks.    

### Checklist

Delete any items that are not applicable to this PR.

- [ ] Any text added follows [EUI's writing
guidelines](https://elastic.github.io/eui/#/guidelines/writing), uses
sentence case text and includes [i18n
support](https://github.com/elastic/kibana/blob/main/packages/kbn-i18n/README.md)
- [ ]
[Documentation](https://www.elastic.co/guide/en/kibana/master/development-documentation.html)
was added for features that require explanation or tutorials
- [x] [Unit or functional
tests](https://www.elastic.co/guide/en/kibana/master/development-tests.html)
were updated or added to match the most common scenarios
- [ ] [Flaky Test
Runner](https://ci-stats.kibana.dev/trigger_flaky_test_runner/1) was
used on any tests changed
- [ ] Any UI touched in this PR is usable by keyboard only (learn more
about [keyboard accessibility](https://webaim.org/techniques/keyboard/))
- [ ] Any UI touched in this PR does not create any new axe failures
(run axe in browser:
[FF](https://addons.mozilla.org/en-US/firefox/addon/axe-devtools/),
[Chrome](https://chrome.google.com/webstore/detail/axe-web-accessibility-tes/lhdoppojpmngadmnindnejefpokejbdd?hl=en-US))
- [ ] If a plugin configuration key changed, check if it needs to be
allowlisted in the cloud and added to the [docker
list](https://github.com/elastic/kibana/blob/main/src/dev/build/tasks/os_packages/docker_generator/resources/base/bin/kibana-docker)
- [ ] This renders correctly on smaller devices using a responsive
layout. (You can test this [in your
browser](https://www.browserstack.com/guide/responsive-testing-on-local-server))
- [ ] This was checked for [cross-browser
compatibility](https://www.elastic.co/support/matrix#matrix_browsers)
  • Loading branch information
joemcelroy committed May 22, 2024
1 parent e8c82e2 commit 51f9eed
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 48 deletions.
154 changes: 106 additions & 48 deletions x-pack/plugins/search_playground/server/lib/conversational_chain.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,30 @@ import type { Client } from '@elastic/elasticsearch';
import { createAssist as Assist } from '../utils/assist';
import { ConversationalChain } from './conversational_chain';
import { FakeListChatModel } from '@langchain/core/utils/testing';
import { FakeListLLM } from 'langchain/llms/fake';
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { Message } from 'ai';

describe('conversational chain', () => {
const createTestChain = async (
responses: string[],
chat: Message[],
expectedFinalAnswer: string,
expectedDocs: any,
expectedTokens: any,
expectedSearchRequest: any,
contentField: Record<string, string> = { index: 'field', website: 'body_content' }
) => {
const createTestChain = async ({
responses,
chat,
expectedFinalAnswer,
expectedDocs,
expectedTokens,
expectedSearchRequest,
contentField = { index: 'field', website: 'body_content' },
isChatModel = true,
}: {
responses: string[];
chat: Message[];
expectedFinalAnswer: string;
expectedDocs: any;
expectedTokens: any;
expectedSearchRequest: any;
contentField?: Record<string, string>;
isChatModel?: boolean;
}) => {
const searchMock = jest.fn().mockImplementation(() => {
return {
hits: {
Expand Down Expand Up @@ -54,9 +65,11 @@ describe('conversational chain', () => {
},
};

const llm = new FakeListChatModel({
responses,
});
const llm = isChatModel
? new FakeListChatModel({
responses,
})
: new FakeListLLM({ responses });

const aiClient = Assist({
es_client: mockElasticsearchClient as unknown as Client,
Expand Down Expand Up @@ -118,17 +131,17 @@ describe('conversational chain', () => {
};

it('should be able to create a conversational chain', async () => {
await createTestChain(
['the final answer'],
[
await createTestChain({
responses: ['the final answer'],
chat: [
{
id: '1',
role: 'user',
content: 'what is the work from home policy?',
},
],
'the final answer',
[
expectedFinalAnswer: 'the final answer',
expectedDocs: [
{
documents: [
{ metadata: { _id: '1', _index: 'index' }, pageContent: 'value' },
Expand All @@ -137,32 +150,32 @@ describe('conversational chain', () => {
type: 'retrieved_docs',
},
],
[
expectedTokens: [
{ type: 'context_token_count', count: 15 },
{ type: 'prompt_token_count', count: 5 },
],
[
expectedSearchRequest: [
{
method: 'POST',
path: '/index,website/_search',
body: { query: { match: { field: 'what is the work from home policy?' } }, size: 3 },
},
]
);
],
});
});

it('should be able to create a conversational chain with nested field', async () => {
await createTestChain(
['the final answer'],
[
await createTestChain({
responses: ['the final answer'],
chat: [
{
id: '1',
role: 'user',
content: 'what is the work from home policy?',
},
],
'the final answer',
[
expectedFinalAnswer: 'the final answer',
expectedDocs: [
{
documents: [
{ metadata: { _id: '1', _index: 'index' }, pageContent: 'value' },
Expand All @@ -171,25 +184,25 @@ describe('conversational chain', () => {
type: 'retrieved_docs',
},
],
[
expectedTokens: [
{ type: 'context_token_count', count: 15 },
{ type: 'prompt_token_count', count: 5 },
],
[
expectedSearchRequest: [
{
method: 'POST',
path: '/index,website/_search',
body: { query: { match: { field: 'what is the work from home policy?' } }, size: 3 },
},
],
{ index: 'field', website: 'metadata.source' }
);
contentField: { index: 'field', website: 'metadata.source' },
});
});

it('asking with chat history should re-write the question', async () => {
await createTestChain(
['rewrite the question', 'the final answer'],
[
await createTestChain({
responses: ['rewrite the question', 'the final answer'],
chat: [
{
id: '1',
role: 'user',
Expand All @@ -206,8 +219,8 @@ describe('conversational chain', () => {
content: 'what is the work from home policy?',
},
],
'the final answer',
[
expectedFinalAnswer: 'the final answer',
expectedDocs: [
{
documents: [
{ metadata: { _id: '1', _index: 'index' }, pageContent: 'value' },
Expand All @@ -216,24 +229,24 @@ describe('conversational chain', () => {
type: 'retrieved_docs',
},
],
[
expectedTokens: [
{ type: 'context_token_count', count: 15 },
{ type: 'prompt_token_count', count: 5 },
],
[
expectedSearchRequest: [
{
method: 'POST',
path: '/index,website/_search',
body: { query: { match: { field: 'rewrite the question' } }, size: 3 },
},
]
);
],
});
});

it('should cope with quotes in the query', async () => {
await createTestChain(
['rewrite "the" question', 'the final answer'],
[
await createTestChain({
responses: ['rewrite "the" question', 'the final answer'],
chat: [
{
id: '1',
role: 'user',
Expand All @@ -250,8 +263,8 @@ describe('conversational chain', () => {
content: 'what is the work from home policy?',
},
],
'the final answer',
[
expectedFinalAnswer: 'the final answer',
expectedDocs: [
{
documents: [
{ metadata: { _id: '1', _index: 'index' }, pageContent: 'value' },
Expand All @@ -260,17 +273,62 @@ describe('conversational chain', () => {
type: 'retrieved_docs',
},
],
[
expectedTokens: [
{ type: 'context_token_count', count: 15 },
{ type: 'prompt_token_count', count: 5 },
],
[
expectedSearchRequest: [
{
method: 'POST',
path: '/index,website/_search',
body: { query: { match: { field: 'rewrite "the" question' } }, size: 3 },
},
]
);
],
});
});

it('should work with an LLM based model', async () => {
await createTestChain({
responses: ['rewrite "the" question', 'the final answer'],
chat: [
{
id: '1',
role: 'user',
content: 'what is the work from home policy?',
},
{
id: '2',
role: 'assistant',
content: 'the final answer',
},
{
id: '3',
role: 'user',
content: 'what is the work from home policy?',
},
],
expectedFinalAnswer: 'the final answer',
expectedDocs: [
{
documents: [
{ metadata: { _id: '1', _index: 'index' }, pageContent: 'value' },
{ metadata: { _id: '1', _index: 'website' }, pageContent: 'value2' },
],
type: 'retrieved_docs',
},
],
expectedTokens: [
{ type: 'context_token_count', count: 15 },
{ type: 'prompt_token_count', count: 7 },
],
expectedSearchRequest: [
{
method: 'POST',
path: '/index,website/_search',
body: { query: { match: { field: 'rewrite "the" question' } }, size: 3 },
},
],
isChatModel: false,
});
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ class ConversationalChainFn {
{
callbacks: [
{
// callback for chat based models (OpenAI)
handleChatModelStart(
llm,
msg: BaseMessage[][],
Expand All @@ -166,6 +167,15 @@ class ConversationalChainFn {
});
}
},
// callback for prompt based models (Bedrock uses ActionsClientLlm)
handleLLMStart(llm, input, runId, parentRunId, extraParams, tags, metadata) {
if (metadata?.type === 'question_answer_qa') {
data.appendMessageAnnotation({
type: 'prompt_token_count',
count: getTokenEstimate(input[0]),
});
}
},
handleRetrieverEnd(documents) {
retrievedDocs.push(...documents);
data.appendMessageAnnotation({
Expand Down

0 comments on commit 51f9eed

Please sign in to comment.