Skip to content

Commit

Permalink
Merge pull request #53 from dev-jpnobrega/feature/addChatHistory
Browse files Browse the repository at this point in the history
- adjust README
  • Loading branch information
dev-jpnobrega committed Feb 23, 2024
2 parents 71401c1 + fd25d41 commit 386cd54
Show file tree
Hide file tree
Showing 8 changed files with 250 additions and 70 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ The documents found are used for the context of the Agent.
systemMesssage: '<a message that will specialize your agent>',
chatConfig: {
temperature: 0,
}
},
llmConfig: {
type: '<cloud-provider-llm-service>', // Check availability at <link>
model: '<llm-model>',
Expand Down
60 changes: 30 additions & 30 deletions src/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@ import {
import { nanoid } from 'ai';
import { interpolate } from './helpers/string.helpers';
import { ChainService, IChainService } from './services/chain';
import ChatHistoryFactory from './services/chat-history';
import { ChatHistoryFactory, IChatHistory } from './services/chat-history';
import LLMFactory from './services/llm';
import VectorStoreFactory from './services/vector-store';
import { BaseMessage } from 'langchain/schema';

const EVENTS_NAME = {
onMessage: 'onMessage',
Expand All @@ -33,6 +34,9 @@ class Agent extends AgentBaseCommand implements IAgent {
private _vectorService: VectorStore;

private _chainService: IChainService;


private _chatHistory: IChatHistory;
private _bufferMemory: BufferMemory;
private _logger: Console;
private _settings: IAgentConfig;
Expand Down Expand Up @@ -61,28 +65,15 @@ class Agent extends AgentBaseCommand implements IAgent {
/**
 * Lazily builds (and caches) the chat-history backend for this agent.
 *
 * NOTE(review): this span contained two interleaved implementations
 * (the old BufferMemory-based one and the new IChatHistory one);
 * resolved here to the IChatHistory version, which the rest of the
 * diff (run(), ChatHistoryFactory) depends on.
 *
 * @param userSessionId - session key for persisted history; a random
 *   nanoid is used when absent (TODO in original: confirm this fallback).
 * @param settings - database/history configuration passed to the factory.
 * @returns the cached or newly created IChatHistory.
 */
private async buildHistory(
  userSessionId: string,
  settings: IDatabaseConfig
): Promise<IChatHistory> {
  // Reuse the existing history so every call shares one session store.
  if (this._chatHistory) return this._chatHistory;

  this._chatHistory = await ChatHistoryFactory.create({
    ...settings,
    sessionId: userSessionId || nanoid(), // TODO: confirm fallback session-id strategy
  });

  return this._chatHistory;
}

private async buildRelevantDocs(
Expand Down Expand Up @@ -115,13 +106,11 @@ class Agent extends AgentBaseCommand implements IAgent {
const { question, chatThreadID } = args;

try {
const memoryChat = await this.buildHistory(
const chatHistory = await this.buildHistory(
chatThreadID,
this._settings.dbHistoryConfig
);

memoryChat.chatHistory?.addUserMessage(question);

const { relevantDocs, referenciesDocs } = await this.buildRelevantDocs(
args,
this._settings.vectorStoreConfig
Expand All @@ -130,21 +119,23 @@ class Agent extends AgentBaseCommand implements IAgent {
const chain = await this._chainService.build(
this._llm,
question,
memoryChat
chatHistory.getBufferMemory(),
);
const chat_history = await memoryChat.chatHistory?.getMessages();

const chatMessages = await chatHistory.getMessages();

const result = await chain.call({
referencies: referenciesDocs,
input_documents: relevantDocs,
query: question,
question: question,
chat_history: chat_history?.slice(
-(this._settings?.dbHistoryConfig?.limit || 5)
),
chat_history: chatMessages,
format_chat_messages: chatHistory.getFormatedMessages(chatMessages),
user_prompt: this._settings.systemMesssage,
});

await memoryChat.chatHistory?.addAIChatMessage(result?.text);
await chatHistory.addUserMessage(question);
await chatHistory.addAIChatMessage(result?.text);

this.emit(EVENTS_NAME.onMessage, result?.text);

Expand All @@ -157,6 +148,15 @@ class Agent extends AgentBaseCommand implements IAgent {
}
}

/**
 * Renders the most recent chat messages as plain text, one message per
 * line, in the form "TYPE: content" (type uppercased via _getType()).
 * Only the last `dbHistoryConfig.limit` messages (default 5) are kept.
 */
getMessageFormat(messages: BaseMessage[]): string {
  const historyLimit = this._settings?.dbHistoryConfig?.limit || 5;
  const recentMessages = messages.slice(-historyLimit);

  const renderedLines: string[] = [];
  for (const message of recentMessages) {
    renderedLines.push(`${message._getType().toUpperCase()}: ${message.content}`);
  }

  return renderedLines.join('\n');
}

execute(args: any): Promise<void> {
throw new Error(args);
}
Expand Down
15 changes: 15 additions & 0 deletions src/services/chain/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,19 @@ class ChainService {

builtMessage += '\n';
builtMessage += `
Given the user prompt and conversation log, the document context, the API output, and the following database output, formulate a response from a knowledge base.\n
You must follow the following rules and priorities when generating and responding:\n
- Always prioritize user prompt over conversation record.\n
- Ignore any conversation logs that are not directly related to the user prompt.\n
- Only try to answer if a question is asked.\n
- The question must be a single sentence.\n
- You must remove any punctuation from the question.\n
- You must remove any words that are not relevant to the question.\n
- If you are unable to formulate a question, respond in a friendly manner so the user can rephrase the question.\n\n
USER PROMPT: {user_prompt}\n
--------------------------------------
CHAT HISTORY: {format_chat_messages}\n
--------------------------------------
Context found in documents: {summaries}\n
--------------------------------------
Expand Down Expand Up @@ -132,6 +145,8 @@ class ChainService {
'input_documents',
'question',
'chat_history',
'format_chat_messages',
'user_prompt'
],
verbose: this._settings.debug || false,
memory: memoryChat,
Expand Down
81 changes: 57 additions & 24 deletions src/services/chain/openapi-base-chain.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@ export class OpenApiBaseChain extends BaseChain {
readonly inputKey = 'query';
readonly outputKey = 'openAPIResult';
private _input: OpenApiBaseChainInput;
private _logger: Console;

// Wires the chain configuration and a console-backed logger.
constructor(input: OpenApiBaseChainInput) {
  super(input);
  this._input = input;
  // NOTE(review): logger is the global console; field exists presumably
  // as a seam for future injection — confirm before swapping it out.
  this._logger = console;
}

get inputKeys(): string[] {
Expand All @@ -38,19 +40,23 @@ export class OpenApiBaseChain extends BaseChain {
}

/**
 * System prompt for the OpenAPI/Swagger agent. Placeholders
 * ({user_prompt}, {schema}, {format_chat_messages}, {question}) are
 * filled in by the prompt template at call time.
 *
 * NOTE(review): this span contained both the pre-change and post-change
 * prompt bodies with two `return` statements (the second unreachable);
 * resolved here to the post-change prompt only.
 */
private getOpenApiPrompt(): string {
  return `
  You are an AI with expertise in OpenAPI and Swagger.\n
  You should follow the following rules when generating and answer:\n
  - Only execute the request on the service if the question is not in CHAT HISTORY, if the question has already been answered, use the same answer and do not make a request on the service.
  - Only attempt to answer if a question was posed.\n
  - Always answer the question in the language in which the question was asked.\n\n
  -------------------------------------------\n
  USER PROMPT: {user_prompt}\n
  -------------------------------------------\n
  SCHEMA: {schema}\n
  -------------------------------------------\n
  CHAT HISTORY: {format_chat_messages}\n
  -------------------------------------------\n
  QUESTION: {question}\n
  ------------------------------------------\n
  API ANSWER:
  `;
}

private buildPromptTemplate(systemMessages: string): BasePromptTemplate {
Expand All @@ -66,12 +72,27 @@ export class OpenApiBaseChain extends BaseChain {
return CHAT_COMBINE_PROMPT;
}

/**
 * Recovers the model's text payload from an error message of the form
 * "… No function_call in message <json array>". Returns the input
 * unchanged when the marker is absent; returns a friendly fallback when
 * the payload cannot be parsed.
 *
 * Fixes vs. original:
 * - The original checked `includes('No function_call in message')` but
 *   split on the same marker WITH a trailing space; split on the checked
 *   marker and trim instead, so a missing trailing space no longer
 *   forces the fallback path.
 * - `txtJson[0]?.text` could be undefined, letting a `: string` method
 *   return undefined; coalesce to the fallback message.
 */
private tryParseText(text: string): string {
  const marker = 'No function_call in message';
  if (!text.includes(marker)) return text;

  const fallback = `Sorry, I could not find the answer to your question.`;

  try {
    const rawPayload = (text.split(marker)[1] ?? '').trim();
    const parsed = JSON.parse(rawPayload);

    return parsed?.[0]?.text ?? fallback;
  } catch (error) {
    return fallback;
  }
}

async _call(
values: ChainValues,
runManager?: CallbackManagerForChainRun
): Promise<ChainValues> {
console.log('Values: ', values);
console.log('OPENAPI Input: ', values[this.inputKey]);
this._logger.log('Values: ', values);
this._logger.log('OPENAPI Input: ', values[this.inputKey]);

const question = values[this.inputKey];
const schema = this._input.spec;
Expand All @@ -83,15 +104,27 @@ export class OpenApiBaseChain extends BaseChain {
verbose: true,
});

const answer = await chain.invoke({
question,
schema,
chat_history: values?.chat_history,
});

console.log('OPENAPI Resposta: ', answer);

return { [this.outputKey]: answer?.response };
let answer:string = '';

try {
const rs = await chain.invoke({
question,
schema,
chat_history: values?.chat_history,
format_chat_messages: values?.format_chat_messages,
user_prompt: this._input.customizeSystemMessage || '',
});

this._logger.log('OPENAPI Resposta: ', answer);

answer = rs?.response;
} catch (error) {
this._logger.error('OPENAPI Error: ', error);

answer = this.tryParseText(error?.message);
} finally {
return { [this.outputKey]: answer };
}
}

_chainType(): string {
Expand Down
16 changes: 9 additions & 7 deletions src/services/chain/sql-database-chain.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,12 @@ export default class SqlDatabaseChain extends BaseChain {
Your response must only be a valid SQL query, based on the schema provided.\n
-------------------------------------------\n
Here are some important observations for generating the query:\n
${this.customMessage}\n
- Only execute the request on the service if the question is not in CHAT HISTORY, if the question has already been answered, use the same answer and do not make a query on the database.\n
{user_prompt}\n
-------------------------------------------\n
SCHEMA: {schema}\n
-------------------------------------------\n
CHAT HISTORY: {chat_history}\n
CHAT HISTORY: {format_chat_messages}\n
-------------------------------------------\n
QUESTION: {question}\n
------------------------------------------\n
Expand Down Expand Up @@ -123,8 +124,7 @@ export default class SqlDatabaseChain extends BaseChain {

return sqlBlock;
}

throw new Error(MESSAGES_ERRORS.dataEmpty);
return;
}

// TODO: check implementation for big data
Expand Down Expand Up @@ -173,6 +173,8 @@ export default class SqlDatabaseChain extends BaseChain {
schema: () => table_schema,
question: (input: { question: string }) => input.question,
chat_history: () => values?.chat_history,
format_chat_messages: () => values?.format_chat_messages,
user_prompt: () => this.customMessage,
},
this.buildPromptTemplate(this.getSQLPrompt()),
this.llm.bind({ stop: ['\nSQLResult:'] }),
Expand All @@ -190,12 +192,12 @@ export default class SqlDatabaseChain extends BaseChain {
question: (input) => input.question,
query: (input) => input.query,
response: async (input) => {
const sql = input.query.content;
const text = input.query.content;

try {
const sqlParserd = this.parserSQL(sql);
const sqlParserd = this.parserSQL(text);

if (!sqlParserd) return null;
if (!sqlParserd) return text;

console.log(`SQL`, sqlParserd);

Expand Down
30 changes: 26 additions & 4 deletions src/services/chat-history/index.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1,39 @@

import { IDatabaseConfig } from '../../interface/agent.interface';
import { BaseChatMessageHistory } from 'langchain/schema';
import { BaseChatMessageHistory, BaseMessage } from 'langchain/schema';

import RedisChatHistory from './redis-chat-history';
import { BufferMemory } from 'langchain/memory';
import MemoryChatHistory from './memory-chat-history';

// Contract shared by all chat-history backends (redis, in-memory).
interface IChatHistory {
  /** Appends a human message to the session's history. */
  addUserMessage(message: string): Promise<void>;
  /** Appends an AI-generated message to the session's history. */
  addAIChatMessage(message: string): Promise<void>;
  /** Returns every stored message for the session. */
  getMessages(): Promise<BaseMessage[]>;
  /** Renders messages as plain text (implementation-defined layout).
   *  NOTE(review): name misspells "Formatted" — kept for compatibility. */
  getFormatedMessages(messages: BaseMessage[]): string;
  /** Removes all messages for the session. */
  clear(): Promise<void>;
  /** Underlying langchain message store. */
  getChatHistory(): BaseChatMessageHistory;
  /** BufferMemory wrapper over this history, used when building chains. */
  getBufferMemory(): BufferMemory;
}

// Registry mapping IDatabaseConfig.type ('redis' | 'memory') to its
// chat-history implementation.
// NOTE(review): `as any` permits arbitrary string indexing but disables
// checking — a Record keyed by the supported types would be safer.
const Services = {
  redis: RedisChatHistory,
  memory: MemoryChatHistory,
} as any;

class ChatHistoryFactory {
public static async create(settings: IDatabaseConfig): Promise<BaseChatMessageHistory> {
return await new Services[settings.type](settings).build();
public static async create(settings: IDatabaseConfig): Promise<IChatHistory> {
const service = new Services[settings?.type](settings);

if (!service) {
return await new MemoryChatHistory(settings).build();
}

return await service.build();
}
}

// NOTE(review): this span interleaved the old default export with the
// new named exports; keep BOTH so existing `import ChatHistoryFactory
// from …` callers keep working alongside the new named-import style.
export default ChatHistoryFactory;
export {
  IChatHistory,
  ChatHistoryFactory,
};
Loading

0 comments on commit 386cd54

Please sign in to comment.