diff --git a/src/oss/langgraph/sql-agent.mdx b/src/oss/langgraph/sql-agent.mdx
index a9df4ed24..8dc3f38fe 100644
--- a/src/oss/langgraph/sql-agent.mdx
+++ b/src/oss/langgraph/sql-agent.mdx
@@ -15,7 +15,6 @@ import StableCalloutJS from '/snippets/stable-lg-callout-js.mdx';
:::
-:::python
In this tutorial we will build a custom agent that can answer questions about a SQL database using LangGraph.
LangChain offers built-in [agent](/oss/langchain/agents) implementations built on [LangGraph](/oss/langgraph/overview) primitives. If deeper customization is required, agents can be implemented directly in LangGraph. This guide demonstrates an example implementation of a SQL agent. You can find a tutorial that builds a SQL agent using higher-level LangChain abstractions [here](/oss/langchain/sql-agent).
@@ -40,11 +39,26 @@ We will cover the following concepts:
### Installation
+:::python
```bash pip
pip install langchain langgraph langchain-community
```
+:::
+:::js
+
+ ```bash npm
+ npm i langchain @langchain/core @langchain/classic @langchain/langgraph @langchain/openai typeorm sqlite3 zod
+ ```
+ ```bash yarn
+ yarn add langchain @langchain/core @langchain/classic @langchain/langgraph @langchain/openai typeorm sqlite3 zod
+ ```
+ ```bash pnpm
+ pnpm add langchain @langchain/core @langchain/classic @langchain/langgraph @langchain/openai typeorm sqlite3 zod
+ ```
+
+:::
### LangSmith
Set up [LangSmith](https://smith.langchain.com) to inspect what is happening inside your chain or agent. Then set the following environment variables:
@@ -58,7 +72,12 @@ Set up [LangSmith](https://smith.langchain.com) to inspect what is happening ins
Select a model that supports [tool-calling](/oss/integrations/providers/overview):
+:::python
+:::
+:::js
+
+:::
The output shown in the examples below used OpenAI.
@@ -68,6 +87,7 @@ You will be creating a [SQLite database](https://www.sqlitetutorial.net/sqlite-s
For convenience, we have hosted the database (`Chinook.db`) on a public GCS bucket.
+:::python
```python
import requests, pathlib
@@ -101,9 +121,57 @@ Dialect: sqlite
Available tables: ['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']
Sample output: [(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains')]
```
+:::
+:::js
+```typescript
+import fs from "node:fs/promises";
+import path from "node:path";
+
+const url = "https://storage.googleapis.com/benchmarks-artifacts/chinook/Chinook.db";
+const localPath = path.resolve("Chinook.db");
+
+async function resolveDbPath() {
+ const exists = await fs.access(localPath).then(() => true).catch(() => false);
+ if (exists) {
+ console.log(`${localPath} already exists, skipping download.`);
+ return localPath;
+ }
+ const resp = await fetch(url);
+ if (!resp.ok) throw new Error(`Failed to download DB. Status code: ${resp.status}`);
+ const buf = Buffer.from(await resp.arrayBuffer());
+ await fs.writeFile(localPath, buf);
+ console.log(`File downloaded and saved as ${localPath}`);
+ return localPath;
+}
+```
+
+We will use the `SqlDatabase` wrapper available in the `@langchain/classic/sql_db` module to interact with the database. The wrapper provides a simple interface to execute SQL queries and fetch results:
+
+```typescript
+import { SqlDatabase } from "@langchain/classic/sql_db";
+import { DataSource } from "typeorm";
+
+const dbPath = await resolveDbPath();
+const datasource = new DataSource({ type: "sqlite", database: dbPath });
+const db = await SqlDatabase.fromDataSourceParams({ appDataSource: datasource });
+const dialect = db.appDataSourceOptions.type;
+
+console.log(`Dialect: ${dialect}`);
+const tableNames = db.allTables.map(t => t.tableName);
+console.log(`Available tables: ${tableNames.join(", ")}`);
+const sampleResults = await db.run("SELECT * FROM Artist LIMIT 5;");
+console.log(`Sample output: ${sampleResults}`);
+```
+```
+Dialect: sqlite
+Available tables: Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
+Sample output: [{"ArtistId":1,"Name":"AC/DC"},{"ArtistId":2,"Name":"Accept"},{"ArtistId":3,"Name":"Aerosmith"},{"ArtistId":4,"Name":"Alanis Morissette"},{"ArtistId":5,"Name":"Alice In Chains"}]
+```
+:::
## 3. Add tools for database interactions
+:::python
Use the `SQLDatabase` wrapper available in the `langchain_community` package to interact with the database. The wrapper provides a simple interface to execute SQL queries and fetch results:
```python
@@ -125,6 +193,75 @@ sql_db_list_tables: Input is an empty string, output is a comma-separated list o
sql_db_query_checker: Use this tool to double check if your query is correct before executing it. Always use this tool before executing a query with sql_db_query!
```
+:::
+:::js
+We'll create custom tools to interact with the database:
+
+```typescript
+import { tool } from "@langchain/core/tools";
+import { z } from "zod";
+
+// Tool to list all tables
+const listTablesTool = tool(
+ async () => {
+ const tableNames = db.allTables.map(t => t.tableName);
+ return tableNames.join(", ");
+ },
+ {
+ name: "sql_db_list_tables",
+ description: "Input is an empty string, output is a comma-separated list of tables in the database.",
+ schema: z.object({}),
+ }
+);
+
+// Tool to get schema for specific tables
+const getSchemaTool = tool(
+ async ({ table_names }) => {
+ const tables = table_names.split(",").map(t => t.trim());
+ return await db.getTableInfo(tables);
+ },
+ {
+ name: "sql_db_schema",
+ description: "Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3",
+ schema: z.object({
+ table_names: z.string().describe("Comma-separated list of table names"),
+ }),
+ }
+);
+
+// Tool to execute SQL query
+const queryTool = tool(
+ async ({ query }) => {
+ try {
+ const result = await db.run(query);
+ return typeof result === "string" ? result : JSON.stringify(result);
+ } catch (error) {
+ return `Error: ${error instanceof Error ? error.message : String(error)}`;
+ }
+ },
+ {
+ name: "sql_db_query",
+ description: "Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again.",
+ schema: z.object({
+ query: z.string().describe("SQL query to execute"),
+ }),
+ }
+);
+
+const tools = [listTablesTool, getSchemaTool, queryTool];
+
+for (const t of tools) {
+ console.log(`${t.name}: ${t.description}\n`);
+}
+```
+```
+sql_db_list_tables: Input is an empty string, output is a comma-separated list of tables in the database.
+
+sql_db_schema: Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3
+
+sql_db_query: Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again.
+```
+:::
## 4. Define application steps
@@ -137,6 +274,7 @@ We construct dedicated nodes for the following steps:
Putting these steps in dedicated nodes lets us (1) force tool-calls when needed, and (2) customize the prompts associated with each step.
+:::python
```python
from typing import Literal
@@ -245,11 +383,117 @@ def check_query(state: MessagesState):
return {"messages": [response]}
```
+:::
+:::js
+```typescript
+import { AIMessage, ToolMessage, SystemMessage, HumanMessage } from "@langchain/core/messages";
+import { ToolNode } from "@langchain/langgraph/prebuilt";
+import { MessagesAnnotation, StateGraph, START, END } from "@langchain/langgraph";
+
+// Create tool nodes for schema and query execution
+const getSchemaNode = new ToolNode([getSchemaTool]);
+const runQueryNode = new ToolNode([queryTool]);
+
+// Example: create a predetermined tool call
+async function listTables(state: typeof MessagesAnnotation.State) {
+ const toolCall = {
+ name: "sql_db_list_tables",
+ args: {},
+ id: "abc123",
+ type: "tool_call" as const,
+ };
+ const toolCallMessage = new AIMessage({
+ content: "",
+ tool_calls: [toolCall],
+ });
+
+ const toolMessage = await listTablesTool.invoke({});
+ const response = new AIMessage(`Available tables: ${toolMessage}`);
+
+ return { messages: [toolCallMessage, new ToolMessage({ content: toolMessage, tool_call_id: "abc123" }), response] };
+}
+
+// Example: force a model to create a tool call
+async function callGetSchema(state: typeof MessagesAnnotation.State) {
+ const llmWithTools = llm.bindTools([getSchemaTool], {
+ tool_choice: "any",
+ });
+ const response = await llmWithTools.invoke(state.messages);
+
+ return { messages: [response] };
+}
+
+const topK = 5;
+
+const generateQuerySystemPrompt = `
+You are an agent designed to interact with a SQL database.
+Given an input question, create a syntactically correct ${dialect}
+query to run, then look at the results of the query and return the answer. Unless
+the user specifies a specific number of examples they wish to obtain, always limit
+your query to at most ${topK} results.
+
+You can order the results by a relevant column to return the most interesting
+examples in the database. Never query for all the columns from a specific table,
+only ask for the relevant columns given the question.
+
+DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
+`;
+
+async function generateQuery(state: typeof MessagesAnnotation.State) {
+ const systemMessage = new SystemMessage(generateQuerySystemPrompt);
+ // We do not force a tool call here, to allow the model to
+ // respond naturally when it obtains the solution.
+ const llmWithTools = llm.bindTools([queryTool]);
+ const response = await llmWithTools.invoke([systemMessage, ...state.messages]);
+
+ return { messages: [response] };
+}
+
+const checkQuerySystemPrompt = `
+You are a SQL expert with a strong attention to detail.
+Double check the ${dialect} query for common mistakes, including:
+- Using NOT IN with NULL values
+- Using UNION when UNION ALL should have been used
+- Using BETWEEN for exclusive ranges
+- Data type mismatch in predicates
+- Properly quoting identifiers
+- Using the correct number of arguments for functions
+- Casting to the correct data type
+- Using the proper columns for joins
+
+If there are any of the above mistakes, rewrite the query. If there are no mistakes,
+just reproduce the original query.
+
+You will call the appropriate tool to execute the query after running this check.
+`;
+
+async function checkQuery(state: typeof MessagesAnnotation.State) {
+ const systemMessage = new SystemMessage(checkQuerySystemPrompt);
+
+ // Generate an artificial user message to check
+ const lastMessage = state.messages[state.messages.length - 1] as AIMessage;
+ if (!lastMessage.tool_calls || lastMessage.tool_calls.length === 0) {
+ throw new Error("No tool calls found in the last message");
+ }
+ const toolCall = lastMessage.tool_calls[0];
+ const userMessage = new HumanMessage(toolCall.args.query);
+ const llmWithTools = llm.bindTools([queryTool], {
+ tool_choice: "any",
+ });
+ const response = await llmWithTools.invoke([systemMessage, userMessage]);
+ // Preserve the original message ID
+ response.id = lastMessage.id;
+
+ return { messages: [response] };
+}
+```
+:::
## 5. Implement the agent
We can now assemble these steps into a workflow using the [Graph API](/oss/langgraph/graph-api). We define a [conditional edge](/oss/langgraph/graph-api#conditional-edges) at the query generation step that routes to the query checker if a query is generated, or ends the run if no tool calls are present, meaning the LLM has delivered its final response.
+:::python
```python
def should_continue(state: MessagesState) -> Literal[END, "check_query"]:
messages = state["messages"]
@@ -288,6 +532,47 @@ from langchain_core.runnables.graph import CurveStyle, MermaidDrawMethod, NodeSt
display(Image(agent.get_graph().draw_mermaid_png()))
```
+:::
+:::js
+```typescript
+function shouldContinue(state: typeof MessagesAnnotation.State): "check_query" | typeof END {
+ const messages = state.messages;
+ const lastMessage = messages[messages.length - 1] as AIMessage;
+ if (!lastMessage.tool_calls || lastMessage.tool_calls.length === 0) {
+ return END;
+ } else {
+ return "check_query";
+ }
+}
+
+const builder = new StateGraph(MessagesAnnotation)
+ .addNode("list_tables", listTables)
+ .addNode("call_get_schema", callGetSchema)
+ .addNode("get_schema", getSchemaNode)
+ .addNode("generate_query", generateQuery)
+ .addNode("check_query", checkQuery)
+ .addNode("run_query", runQueryNode)
+ .addEdge(START, "list_tables")
+ .addEdge("list_tables", "call_get_schema")
+ .addEdge("call_get_schema", "get_schema")
+ .addEdge("get_schema", "generate_query")
+ .addConditionalEdges("generate_query", shouldContinue)
+ .addEdge("check_query", "run_query")
+ .addEdge("run_query", "generate_query");
+
+const agent = builder.compile();
+```
+We visualize the application below:
+```typescript
+import * as fs from "node:fs/promises";
+
+const drawableGraph = await agent.getGraphAsync();
+const image = await drawableGraph.drawMermaidPng();
+const imageBuffer = new Uint8Array(await image.arrayBuffer());
+
+await fs.writeFile("graph.png", imageBuffer);
+```
+:::
We can now invoke the graph:
+:::python
```python
question = "Which genre on average has the longest tracks?"
@@ -304,6 +590,24 @@ for step in agent.stream(
):
step["messages"][-1].pretty_print()
```
+:::
+:::js
+```typescript
+const question = "Which genre on average has the longest tracks?";
+
+const stream = await agent.stream(
+ { messages: [{ role: "user", content: question }] },
+ { streamMode: "values" }
+);
+
+for await (const step of stream) {
+ if (step.messages && step.messages.length > 0) {
+ const lastMessage = step.messages[step.messages.length - 1];
+ console.log(lastMessage.toFormattedString());
+ }
+}
+```
+:::
```
================================ Human Message =================================
@@ -379,9 +683,16 @@ Name: sql_db_query
The genre with the longest tracks on average is "Sci Fi & Fantasy," with an average track length of approximately 2,911,783 milliseconds. Other genres with relatively long tracks include "Science Fiction," "Drama," "TV Shows," and "Comedy."
```
+:::python
See [LangSmith trace](https://smith.langchain.com/public/94b8c9ac-12f7-4692-8706-836a1f30f1ea/r) for the above run.
+:::
+:::js
+
+See [LangSmith trace](https://smith.langchain.com/public/a6a96896-686a-4040-b9b5-28d701453d6f/r) for the above run.
+
+:::
## 6. Implement human-in-the-loop review
@@ -390,6 +701,8 @@ It can be prudent to check the agent's SQL queries before they are executed for
Here we leverage LangGraph's [human-in-the-loop](/oss/langgraph/interrupts) features to pause the run before executing a SQL query and wait for human review. Using LangGraph's [persistence layer](/oss/langgraph/persistence), we can pause the run indefinitely (or at least as long as the persistence layer is alive).
Let's wrap the `sql_db_query` tool in a node that receives human input. We can implement this using the [interrupt](/oss/langgraph/interrupts) function. Below, we allow for input to approve the tool call, edit its arguments, or provide user feedback.
+
+:::python
```python
from langchain_core.runnables import RunnableConfig
from langchain.tools import tool
@@ -423,11 +736,56 @@ def run_query_tool_with_interrupt(config: RunnableConfig, **tool_input):
return tool_response
```
+:::
+:::js
+```typescript
+import { RunnableConfig } from "@langchain/core/runnables";
+import { tool } from "@langchain/core/tools";
+import { interrupt } from "@langchain/langgraph";
+
+const queryToolWithInterrupt = tool(
+ async (input, config: RunnableConfig) => {
+ const request = {
+ action: queryTool.name,
+ args: input,
+ description: "Please review the tool call",
+ };
+ const response = interrupt([request]); // [!code highlight]
+ // approve the tool call
+ if (response.type === "accept") {
+ const toolResponse = await queryTool.invoke(input, config);
+ return toolResponse;
+ }
+ // update tool call args
+ else if (response.type === "edit") {
+ const editedInput = response.args.args;
+ const toolResponse = await queryTool.invoke(editedInput, config);
+ return toolResponse;
+ }
+ // respond to the LLM with user feedback
+ else if (response.type === "response") {
+ const userFeedback = response.args;
+ return userFeedback;
+ } else {
+ throw new Error(`Unsupported interrupt response type: ${response.type}`);
+ }
+ },
+ {
+ name: queryTool.name,
+ description: queryTool.description,
+ schema: queryTool.schema,
+ }
+);
+```
+:::
+
The above implementation follows the [tool interrupt example](/oss/langgraph/interrupts#configuring-interrupts) in the broader [human-in-the-loop](/oss/langgraph/interrupts) guide. Refer to that guide for details and alternatives.
Let's now re-assemble our graph. We will replace the programmatic check with human review. Note that we now include a [checkpointer](/oss/langgraph/persistence); this is required to pause and resume the run.
+
+:::python
```python
from langgraph.checkpoint.memory import InMemorySaver
@@ -459,7 +817,44 @@ builder.add_edge("run_query", "generate_query")
checkpointer = InMemorySaver() # [!code highlight]
agent = builder.compile(checkpointer=checkpointer) # [!code highlight]
```
+:::
+:::js
+```typescript
+import { MemorySaver } from "@langchain/langgraph";
+
+function shouldContinueWithHuman(state: typeof MessagesAnnotation.State): "run_query" | typeof END {
+ const messages = state.messages;
+ const lastMessage = messages[messages.length - 1] as AIMessage;
+ if (!lastMessage.tool_calls || lastMessage.tool_calls.length === 0) {
+ return END;
+ } else {
+ return "run_query";
+ }
+}
+
+const runQueryNodeWithInterrupt = new ToolNode([queryToolWithInterrupt]);
+
+const builderWithHuman = new StateGraph(MessagesAnnotation)
+ .addNode("list_tables", listTables)
+ .addNode("call_get_schema", callGetSchema)
+ .addNode("get_schema", getSchemaNode)
+ .addNode("generate_query", generateQuery)
+ .addNode("run_query", runQueryNodeWithInterrupt)
+ .addEdge(START, "list_tables")
+ .addEdge("list_tables", "call_get_schema")
+ .addEdge("call_get_schema", "get_schema")
+ .addEdge("get_schema", "generate_query")
+ .addConditionalEdges("generate_query", shouldContinueWithHuman)
+ .addEdge("run_query", "generate_query");
+
+const checkpointer = new MemorySaver(); // [!code highlight]
+const agentWithHuman = builderWithHuman.compile({ checkpointer }); // [!code highlight]
+```
+:::
+
We can invoke the graph as before. This time, execution is interrupted:
+
+:::python
```python
import json
@@ -482,6 +877,34 @@ for step in agent.stream(
else:
pass
```
+:::
+:::js
+```typescript
+const config = { configurable: { thread_id: "1" } };
+
+const question = "Which genre on average has the longest tracks?";
+
+const stream = await agentWithHuman.stream(
+ { messages: [{ role: "user", content: question }] },
+ { ...config, streamMode: "values" }
+);
+
+for await (const step of stream) {
+ if (step.messages && step.messages.length > 0) {
+ const lastMessage = step.messages[step.messages.length - 1];
+ console.log(lastMessage.toFormattedString());
+ }
+}
+
+// Check for interrupts
+const state = await agentWithHuman.getState(config);
+if (state.next.length > 0) {
+ console.log("\nINTERRUPTED:");
+ console.log(JSON.stringify(state.tasks[0].interrupts[0], null, 2));
+}
+```
+:::
+
```
...
@@ -495,6 +918,8 @@ INTERRUPTED:
}
```
We can accept or edit the tool call using [Command](/oss/langgraph/use-graph-api#combine-control-flow-and-state-updates-with-command):
+
+:::python
```python
from langgraph.types import Command
@@ -515,6 +940,26 @@ for step in agent.stream(
else:
pass
```
+:::
+:::js
+```typescript
+import { Command } from "@langchain/langgraph";
+
+const resumeStream = await agentWithHuman.stream(
+ new Command({ resume: { type: "accept" } }),
+ // new Command({ resume: { type: "edit", args: { query: "..." } } }),
+ { ...config, streamMode: "values" }
+);
+
+for await (const step of resumeStream) {
+ if (step.messages && step.messages.length > 0) {
+ const lastMessage = step.messages[step.messages.length - 1];
+ console.log(lastMessage.toFormattedString());
+ }
+}
+```
+:::
+
```
================================== Ai Message ==================================
Tool Calls:
@@ -535,9 +980,3 @@ Refer to the [human-in-the-loop guide](/oss/langgraph/interrupts) for details.
## Next steps
Check out the [Evaluate a graph](/langsmith/evaluate-graph) guide for evaluating LangGraph applications, including SQL agents like this one, using LangSmith.
-:::
-:::js
-## Under construction
-
-This tutorial has not yet been implemented in Typescript. Refer to the LangChain [SQL agent guide](/oss/langchain/sql-agent) for a reference implementation.
-:::