-
Notifications
You must be signed in to change notification settings - Fork 2.2k
/
sql.ts
199 lines (167 loc) · 5.56 KB
/
sql.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
import type { BaseLanguageModelInterface } from "@langchain/core/language_models/base";
import { OpenAI } from "@langchain/openai";
import { Tool } from "@langchain/core/tools";
import { PromptTemplate } from "@langchain/core/prompts";
import { LLMChain } from "../chains/llm_chain.js";
import type { SqlDatabase } from "../sql_db.js";
import { SqlTable } from "../util/sql_utils.js";
/**
 * Contract for the SQL tools in this module that need direct database
 * access (QuerySqlTool, InfoSqlTool, ListTablesSqlTool): each holds the
 * `SqlDatabase` it operates on in a public `db` property.
 */
interface SqlTool {
// The database the implementing tool queries or inspects.
db: SqlDatabase;
}
/**
 * Tool that executes a raw SQL statement against the wrapped database.
 *
 * The input string is handed verbatim to `SqlDatabase.run`. Any failure is
 * converted to its string form and returned as the tool output instead of
 * being thrown, so the calling agent can read the error and retry with a
 * corrected query.
 */
export class QuerySqlTool extends Tool implements SqlTool {
  static lc_name() {
    return "QuerySqlTool";
  }

  name = "query-sql";

  db: SqlDatabase;

  constructor(db: SqlDatabase) {
    // NOTE(review): this forwards the raw constructor arguments (the database
    // itself) to the base Tool constructor, whereas the sibling tools call
    // `super()` with no arguments — confirm this is intentional.
    super(...arguments);
    this.db = db;
  }

  /** @ignore */
  async _call(input: string) {
    const database = this.db;
    try {
      const result = await database.run(input);
      return result;
    } catch (e) {
      // Surface the failure to the agent as text rather than propagating it.
      return `${e}`;
    }
  }

  description = `Input to this tool is a detailed and correct SQL query, output is a result from the database.
If the query is not correct, an error message will be returned.
If an error is returned, rewrite the query, check the query, and try again.`;
}
/**
 * Tool that returns schema information and sample rows for a set of tables.
 *
 * The input is a comma-separated list of table names. Each name is trimmed,
 * and empty entries (e.g. produced by a trailing or doubled comma) are dropped
 * before the list is passed to `SqlDatabase.getTableInfo`. Any failure is
 * converted to its string form and returned as the tool output instead of
 * being thrown, so the agent can correct its input and retry.
 */
export class InfoSqlTool extends Tool implements SqlTool {
  static lc_name() {
    return "InfoSqlTool";
  }

  name = "info-sql";

  db: SqlDatabase;

  constructor(db: SqlDatabase) {
    super();
    this.db = db;
  }

  /** @ignore */
  async _call(input: string) {
    try {
      const tables = input
        .split(",")
        .map((table) => table.trim())
        // Drop blanks from stray commas so they are not looked up as tables.
        .filter((table) => table.length > 0);
      // Do not hand getTableInfo an empty list: how it treats [] is not
      // visible here (NOTE(review): it may mean "all tables" — confirm).
      if (tables.length === 0) {
        return "Error: no table names were provided. Input must be a comma-separated list of table names.";
      }
      return await this.db.getTableInfo(tables);
    } catch (error) {
      return `${error}`;
    }
  }

  description = `Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables.
Be sure that the tables actually exist by calling list-tables-sql first!
Example Input: "table1, table2, table3".`;
}
/**
 * Tool that lists the names of the tables the agent may use.
 *
 * Starting from `db.allTables`, a table is kept only if it appears in
 * `db.includesTables` (when that list is non-empty) and does not appear in
 * `db.ignoreTables` (when that list is non-empty). The surviving names are
 * joined with ", ". Any failure is stringified and returned as the output
 * instead of being thrown.
 */
export class ListTablesSqlTool extends Tool implements SqlTool {
  static lc_name() {
    return "ListTablesSqlTool";
  }

  name = "list-tables-sql";

  db: SqlDatabase;

  constructor(db: SqlDatabase) {
    super();
    this.db = db;
  }

  /** @ignore */
  async _call(_: string) {
    try {
      const db = this.db;
      const candidates: SqlTable[] = db.allTables;
      const restrictToIncluded = db.includesTables.length > 0;
      const excludeIgnored = db.ignoreTables.length > 0;

      const isSelected = (table: SqlTable) =>
        (!restrictToIncluded || db.includesTables.includes(table.tableName)) &&
        !(excludeIgnored && db.ignoreTables.includes(table.tableName));

      return candidates
        .filter(isSelected)
        .map(({ tableName }) => tableName)
        .join(", ");
    } catch (error) {
      return `${error}`;
    }
  }

  description = `Input is an empty string, output is a comma-separated list of tables in the database.`;
}
/**
 * Options object accepted by the QueryCheckerTool constructor as an
 * alternative to passing an LLMChain directly.
 */
type QueryCheckerToolArgs = {
// Pre-built chain used as-is; when provided, `llm` is not consulted.
llmChain?: LLMChain;
// Model used to build a chain from the tool's prompt template when `llmChain`
// is absent; the constructor falls back to `new OpenAI({ temperature: 0 })`
// when this is omitted too.
llm?: BaseLanguageModelInterface;
// Never set (type `never`). It appears to exist so the constructor can read
// `llmChainOrOptions?._chainType` on the `LLMChain | QueryCheckerToolArgs`
// union and tell a chain (which has a `_chainType` method) from options.
_chainType?: never;
};
/**
 * Tool that asks an LLM to review a SQL query for common mistakes and to
 * return either a corrected query or the original one unchanged.
 *
 * The constructor accepts either a ready-made `LLMChain` or a
 * `QueryCheckerToolArgs` object. With options, an explicit `llmChain` is used
 * as-is; otherwise a chain is built from `template` and `options.llm`, falling
 * back to `new OpenAI({ temperature: 0 })` when no model is supplied.
 */
export class QueryCheckerTool extends Tool {
  static lc_name() {
    return "QueryCheckerTool";
  }

  name = "query-checker";

  template = `
{query}
Double check the sqlite query above for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins
If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.`;

  llmChain: LLMChain;

  constructor(llmChainOrOptions?: LLMChain | QueryCheckerToolArgs) {
    super();
    // LLMChain exposes `_chainType` as a method, while the options type
    // declares it as `never`, so a function value identifies a chain.
    const receivedChain = typeof llmChainOrOptions?._chainType === "function";
    if (receivedChain) {
      this.llmChain = llmChainOrOptions as LLMChain;
    } else {
      const options = llmChainOrOptions as QueryCheckerToolArgs;
      const suppliedChain = options?.llmChain;
      if (suppliedChain === undefined) {
        // No chain given: assemble one from this tool's prompt template.
        // `this.template` is already initialised here because class fields
        // are set up right after `super()` returns.
        const prompt = new PromptTemplate({
          template: this.template,
          inputVariables: ["query"],
        });
        const llm = options?.llm ?? new OpenAI({ temperature: 0 });
        this.llmChain = new LLMChain({ llm, prompt });
      } else {
        this.llmChain = suppliedChain;
      }
    }
  }

  /** @ignore */
  async _call(input: string) {
    const variables = { query: input };
    return this.llmChain.predict(variables);
  }

  description = `Use this tool to double check if your query is correct before executing it.
Always use this tool before executing a query with query-sql!`;
}