-
Notifications
You must be signed in to change notification settings - Fork 2.1k
/
sql.ts
89 lines (82 loc) · 2.52 KB
/
sql.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import type { BaseLanguageModelInterface } from "@langchain/core/language_models/base";
import type { ToolInterface } from "@langchain/core/tools";
import {
InfoSqlTool,
ListTablesSqlTool,
QueryCheckerTool,
QuerySqlTool,
} from "../../../tools/sql.js";
import { Toolkit } from "../base.js";
import { SQL_PREFIX, SQL_SUFFIX } from "./prompt.js";
import { renderTemplate } from "../../../prompts/template.js";
import { LLMChain } from "../../../chains/llm_chain.js";
import { ZeroShotAgent, ZeroShotCreatePromptArgs } from "../../mrkl/index.js";
import { AgentExecutor } from "../../executor.js";
import { SqlDatabase } from "../../../sql_db.js";
/**
 * Prompt options accepted by `createSqlAgent`: everything from
 * ZeroShotCreatePromptArgs (prefix, suffix, inputVariables, ...) plus an
 * optional `topK`.
 */
export interface SqlCreatePromptArgs extends ZeroShotCreatePromptArgs {
  /**
   * Value substituted for the `top_k` variable when the prompt prefix is
   * rendered by `createSqlAgent` (which defaults it to 10 when omitted).
   * NOTE(review): presumably the prefix uses it to cap how many result rows
   * the agent asks for — confirm against SQL_PREFIX in ./prompt.js.
   */
  topK?: number;
}
/**
 * Toolkit that bundles the tools a SQL agent uses against a
 * {@link SqlDatabase}.
 *
 * The constructor creates, in this order, a {@link QuerySqlTool}, an
 * {@link InfoSqlTool} and a {@link ListTablesSqlTool} (each bound to `db`),
 * followed by a {@link QueryCheckerTool} that receives the optional `llm`.
 * NOTE(review): confirm what QueryCheckerTool does when `llm` is undefined.
 *
 * `dialect` always starts out as `"sqlite"`; it is NOT derived from `db`.
 * `createSqlAgent` renders it into the prompt prefix, so assign the correct
 * dialect name on the instance when the database is not SQLite.
 *
 * @example
 * ```typescript
 * const model = new ChatOpenAI({});
 * const toolkit = new SqlToolkit(sqlDb, model);
 * const executor = createSqlAgent(model, toolkit);
 * const result = await executor.invoke({
 *   input: "List the total sales per country. Which country's customers spent the most?",
 * });
 * console.log(`Got output ${result.output}`);
 * ```
 */
export class SqlToolkit extends Toolkit {
  /** Tools handed to the agent, in the order listed in the class docs. */
  tools: ToolInterface[];

  /** Database passed to the query, info and list-tables tools. */
  db: SqlDatabase;

  /** SQL dialect name rendered into the agent prompt; starts as "sqlite". */
  dialect = "sqlite";

  constructor(db: SqlDatabase, llm?: BaseLanguageModelInterface) {
    super();
    this.db = db;

    const queryTool = new QuerySqlTool(db);
    const infoTool = new InfoSqlTool(db);
    const listTablesTool = new ListTablesSqlTool(db);
    // `llm` is forwarded as-is, including when it was not supplied.
    const checkerTool = new QueryCheckerTool({ llm });

    this.tools = [queryTool, infoTool, listTablesTool, checkerTool];
  }
}
/**
 * Builds a zero-shot (MRKL) agent over the tools of a {@link SqlToolkit} and
 * wraps it in an {@link AgentExecutor}.
 *
 * The prompt prefix is first rendered as an f-string template with two
 * values, `dialect` (from `toolkit.dialect`) and `top_k` (from `args.topK`).
 * The ZeroShotAgent prompt is then assembled from the toolkit's tools. Only
 * those tools (by name) are allowed for the agent.
 *
 * @param llm - Model used by the LLMChain that drives the agent.
 * @param toolkit - Supplies the tools and the SQL dialect name.
 * @param args - Optional prompt overrides. Defaults are `prefix` = SQL_PREFIX,
 *   `suffix` = SQL_SUFFIX, `inputVariables` = ["input", "agent_scratchpad"]
 *   and `topK` = 10. A default is used only when the property is `undefined`.
 * @returns An executor over the toolkit's tools with
 *   `returnIntermediateSteps` enabled.
 */
export function createSqlAgent(
  llm: BaseLanguageModelInterface,
  toolkit: SqlToolkit,
  args?: SqlCreatePromptArgs
) {
  const {
    prefix: prefixTemplate = SQL_PREFIX,
    suffix: promptSuffix = SQL_SUFFIX,
    inputVariables: promptInputs = ["input", "agent_scratchpad"],
    topK: resultLimit = 10,
  } = args ?? {};
  const sqlTools = toolkit.tools;

  // Values substituted into the prefix template before the prompt is built.
  const templateValues = { dialect: toolkit.dialect, top_k: resultLimit };

  const prompt = ZeroShotAgent.createPrompt(sqlTools, {
    prefix: renderTemplate(prefixTemplate, "f-string", templateValues),
    suffix: promptSuffix,
    inputVariables: promptInputs,
  });

  const agent = new ZeroShotAgent({
    llmChain: new LLMChain({ prompt, llm }),
    allowedTools: sqlTools.map(({ name }) => name),
  });

  return AgentExecutor.fromAgentAndTools({
    agent,
    tools: sqlTools,
    returnIntermediateSteps: true,
  });
}