Skip to content

Commit

Permalink
Merge pull request #34 from dev-jpnobrega/hotfix/sql-chain-adjusts
Browse files Browse the repository at this point in the history
-up version
  • Loading branch information
dev-jpnobrega authored Nov 27, 2023
2 parents 72b5c81 + 967eec8 commit 8560766
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 44 deletions.
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"name": "ai-agent-enterprise",
"description": "AI Agent simplifies the implementation and use of generative AI with LangChain",
"version": "0.0.30",
"version": "0.0.31",
"main": "./build/index.js",
"types": "./build/index.d.ts",
"files": [
Expand Down
4 changes: 1 addition & 3 deletions src/services/chain/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class ChainService implements IChainService {
{sqlResult}\n
--------------------------------------
Query executed:
{sqlCommand}\n
{sqlQuery}\n
`;
}

Expand Down Expand Up @@ -91,8 +91,6 @@ class ChainService implements IChainService {
private async buildChains(llm: BaseChatModel, ...args: any): Promise<BaseChain[]> {
const chains = this.checkEnabledChains(this._settings);

console.warn(`this._settings.systemMesssage`, this._settings.systemMesssage);

const chain = loadQAMapReduceChain(llm, {
combinePrompt: this.buildPromptTemplate(
this._settings.systemMesssage || SYSTEM_MESSAGE_DEFAULT,
Expand Down
26 changes: 12 additions & 14 deletions src/services/chain/sql-chain.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,19 @@ import { BaseChatModel } from 'langchain/chat_models/base';
import { PromptTemplate } from 'langchain/prompts';

const SYSTEM_MESSAGE_DEFAULT = `
Given an input question, first create a syntactically correct {dialect} query to be performed, then execute a query after observing the query results and return the answer.\n
Given an input question, first create a syntactically correct postgres query to be performed, then execute a query after observing the query results and return the answer.\n
Never query all columns in a table. You should only query the possible columns to answer the question. Enclose each column name in double quotation marks (") to denote the delimited identifiers.\n
Pay attention to only use the column names that you can see in the tables below. Be careful not to query columns that don't exist. Also, pay attention to which column is in which table.\n
\n
Use the following format:\n
Question: Ask here\n
SQLQuery: SQL query to be performed\n
SQLResult: Result of SQLQuery\n
Answer: Final answer here\n
\n
Use only the following tables:\n
{table_info}
\n
{input}
\n
SCHEMA: {schema}
------------
QUESTION: {question}
------------
SQL QUERY: {query}
------------
SQL RESPONSE: {response}
\n\n
`;

class SqlChain {
Expand Down Expand Up @@ -55,12 +53,12 @@ class SqlChain {
llm,
database,
outputKey: 'sqlResult',
sqlOutputKey: 'sqlCommand',
sqlOutputKey: 'sqlQuery',
prompt: new PromptTemplate({
inputVariables: ['input', 'chat_history', 'dialect', 'table_info'],
inputVariables: ['question', 'response', 'schema', 'query'],
template: systemTemplate,
}),
});
}, this._settings?.customizeSystemMessage);

return chainSQL;
}
Expand Down
64 changes: 38 additions & 26 deletions src/services/chain/sql-database-chain.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,15 @@ export default class SqlDatabaseChain extends BaseChain {
inputKey = "query";

outputKey = "result";

customMessage = '';

sqlOutputKey: string | undefined = undefined;

// Whether to return the result of querying the SQL table directly.
returnDirect = false;

constructor(fields: SqlDatabaseChainInput) {
constructor(fields: SqlDatabaseChainInput, customMessage?: string) {
super(fields);
this.llm = fields.llm;
this.database = fields.database;
Expand All @@ -64,66 +66,76 @@ export default class SqlDatabaseChain extends BaseChain {
this.outputKey = fields.outputKey ?? this.outputKey;
this.sqlOutputKey = fields.sqlOutputKey ?? this.sqlOutputKey;
this.prompt = fields.prompt;
this.customMessage = customMessage || '';
}

async _call(values: ChainValues, runManager?: CallbackManagerForChainRun): Promise<ChainValues> {
const question: string = values[this.inputKey];

const prompt =
PromptTemplate.fromTemplate(`Based on the provided SQL table schema below, write a SQL query that would answer the user's question.
getSQLPrompt(): PromptTemplate {
return PromptTemplate.fromTemplate(`Based on the provided SQL table schema below, write a SQL query that would answer the user's question.\n
-------------------------------------------
${this.customMessage}\n
------------
SCHEMA: {schema}
------------
QUESTION: {question}
------------
SQL QUERY:`);
}

async _call(values: ChainValues, runManager?: CallbackManagerForChainRun): Promise<ChainValues> {
const question: string = values[this.inputKey];
const table_schema = await this.database.getTableInfo();

const sqlQueryChain = RunnableSequence.from([
{
schema: async () => this.database.getTableInfo(),
schema: () => table_schema,
question: (input: { question: string }) => input.question,
},
prompt,
this.getSQLPrompt(),
this.llm.bind({ stop: ["\nSQLResult:"] })
]);

const responsePrompt =
PromptTemplate.fromTemplate(`Based on the table schema below, question, SQL query, and SQL response, write a natural language response:
------------
SCHEMA: {schema}
------------
QUESTION: {question}
------------
SQL QUERY: {query}
------------
SQL RESPONSE: {response}`);

const finalChain = RunnableSequence.from([
{
question: (input) => input.question,
query: sqlQueryChain,
},
{
schema: async () => this.database.getTableInfo(),
table_info: () => table_schema,
input: () => question,
schema: () => table_schema,
question: (input) => input.question,
query: (input) => input.query,
response: (input) => {
response: async (input) => {
const sql = input.query.content.toLowerCase();

console.log(`SQL`, sql);

if (sql.includes('select') && sql.includes('from')) {
return this.database.run(input.query);
try {
const queryResult = await this.database.run(input.query);

return queryResult;
} catch (error) {
console.error(error);

return '';
}
}

return null;
return '';
},
},
{
[this.outputKey]: responsePrompt.pipe(this.llm).pipe(new StringOutputParser()),
[this.sqlOutputKey]: (previousStepResult) => previousStepResult.query,
[this.outputKey]: this.prompt.pipe(this.llm).pipe(new StringOutputParser()),
[this.sqlOutputKey]: (previousStepResult) => {
return previousStepResult?.query?.content;
},
},
]);

return finalChain.invoke({ question });
const result = await finalChain.invoke({ question });

return result;
}

_chainType(): string {
Expand Down

0 comments on commit 8560766

Please sign in to comment.