diff --git a/src/oss/langgraph/sql-agent.mdx b/src/oss/langgraph/sql-agent.mdx index a9df4ed24..8dc3f38fe 100644 --- a/src/oss/langgraph/sql-agent.mdx +++ b/src/oss/langgraph/sql-agent.mdx @@ -15,7 +15,6 @@ import StableCalloutJS from '/snippets/stable-lg-callout-js.mdx'; ::: -:::python In this tutorial we will build a custom agent that can answer questions about a SQL database using LangGraph. LangChain offers built-in [agent](/oss/langchain/agents) implementations, implemented using [LangGraph](/oss/langgraph/overview) primitives. If deeper customization is required, agents can be implemented directly in LangGraph. This guide demonstrates an example implementation of a SQL agent. You can find a tutorial building a SQL agent using higher-level LangChain abstractions [here](/oss/langchain/sql-agent). @@ -40,11 +39,26 @@ We will cover the following concepts: ### Installation + :::python ```bash pip pip install langchain langgraph langchain-community ``` + ::: + :::js + + ```bash npm + npm i langchain @langchain/core @langchain/classic @langchain/langgraph @langchain/openai typeorm sqlite3 zod + ``` + ```bash yarn + yarn add langchain @langchain/core @langchain/classic @langchain/langgraph @langchain/openai typeorm sqlite3 zod + ``` + ```bash pnpm + pnpm add langchain @langchain/core @langchain/classic @langchain/langgraph @langchain/openai typeorm sqlite3 zod + ``` + + ::: ### LangSmith Set up [LangSmith](https://smith.langchain.com) to inspect what is happening inside your chain or agent. Then set the following environment variables: @@ -58,7 +72,12 @@ Set up [LangSmith](https://smith.langchain.com) to inspect what is happening ins Select a model that supports [tool-calling](/oss/integrations/providers/overview): +:::python +::: +:::js + +::: The output shown in the examples below used OpenAI. 
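For the JavaScript snippets that follow, the chat model can be initialized as shown below. This is a minimal setup sketch, not prescribed by the tutorial: it assumes OpenAI via `@langchain/openai` and the `"gpt-4o"` model name, and binds the instance to `llm`, the name the later snippets reference. Any tool-calling chat model works in its place.

```typescript
// Setup sketch (assumption: OpenAI as the provider, model name "gpt-4o").
// Requires OPENAI_API_KEY to be set in the environment.
import { ChatOpenAI } from "@langchain/openai";

const llm = new ChatOpenAI({ model: "gpt-4o" });
```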
@@ -68,6 +87,7 @@ You will be creating a [SQLite database](https://www.sqlitetutorial.net/sqlite-s For convenience, we have hosted the database (`Chinook.db`) on a public GCS bucket. +:::python ```python import requests, pathlib @@ -101,9 +121,57 @@ Dialect: sqlite Available tables: ['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track'] Sample output: [(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains')] ``` +::: +:::js +```typescript +import fs from "node:fs/promises"; +import path from "node:path"; + +const url = "https://storage.googleapis.com/benchmarks-artifacts/chinook/Chinook.db"; +const localPath = path.resolve("Chinook.db"); + +async function resolveDbPath() { + const exists = await fs.access(localPath).then(() => true).catch(() => false); + if (exists) { + console.log(`${localPath} already exists, skipping download.`); + return localPath; + } + const resp = await fetch(url); + if (!resp.ok) throw new Error(`Failed to download DB. Status code: ${resp.status}`); + const buf = Buffer.from(await resp.arrayBuffer()); + await fs.writeFile(localPath, buf); + console.log(`File downloaded and saved as ${localPath}`); + return localPath; +} +``` + +We will use a handy SQL database wrapper available in the `@langchain/classic/sql_db` module to interact with the database. 
The wrapper provides a simple interface to execute SQL queries and fetch results: + +```typescript +import { SqlDatabase } from "@langchain/classic/sql_db"; +import { DataSource } from "typeorm"; + +const dbPath = await resolveDbPath(); +const datasource = new DataSource({ type: "sqlite", database: dbPath }); +const db = await SqlDatabase.fromDataSourceParams({ appDataSource: datasource }); +const dialect = db.appDataSourceOptions.type; + +console.log(`Dialect: ${dialect}`); +const tableNames = db.allTables.map(t => t.tableName); +console.log(`Available tables: ${tableNames.join(", ")}`); +const sampleResults = await db.run("SELECT * FROM Artist LIMIT 5;"); +console.log(`Sample output: ${sampleResults}`); +``` +``` +Dialect: sqlite +Available tables: Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track +Sample output: [{"ArtistId":1,"Name":"AC/DC"},{"ArtistId":2,"Name":"Accept"},{"ArtistId":3,"Name":"Aerosmith"},{"ArtistId":4,"Name":"Alanis Morissette"},{"ArtistId":5,"Name":"Alice In Chains"}] +``` +::: ## 3. Add tools for database interactions +:::python Use the `SQLDatabase` wrapper available in the `langchain_community` package to interact with the database. The wrapper provides a simple interface to execute SQL queries and fetch results: ```python @@ -125,6 +193,75 @@ sql_db_list_tables: Input is an empty string, output is a comma-separated list o sql_db_query_checker: Use this tool to double check if your query is correct before executing it. Always use this tool before executing a query with sql_db_query! 
``` +::: +:::js +We'll create custom tools to interact with the database: + +```typescript +import { tool } from "@langchain/core/tools"; +import { z } from "zod"; + +// Tool to list all tables +const listTablesTool = tool( + async () => { + const tableNames = db.allTables.map(t => t.tableName); + return tableNames.join(", "); + }, + { + name: "sql_db_list_tables", + description: "Input is an empty string, output is a comma-separated list of tables in the database.", + schema: z.object({}), + } +); + +// Tool to get schema for specific tables +const getSchemaTool = tool( + async ({ table_names }) => { + const tables = table_names.split(",").map(t => t.trim()); + return await db.getTableInfo(tables); + }, + { + name: "sql_db_schema", + description: "Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3", + schema: z.object({ + table_names: z.string().describe("Comma-separated list of table names"), + }), + } +); + +// Tool to execute SQL query +const queryTool = tool( + async ({ query }) => { + try { + const result = await db.run(query); + return typeof result === "string" ? result : JSON.stringify(result); + } catch (error) { + return `Error: ${error.message}`; + } + }, + { + name: "sql_db_query", + description: "Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again.", + schema: z.object({ + query: z.string().describe("SQL query to execute"), + }), + } +); + +const tools = [listTablesTool, getSchemaTool, queryTool]; + +for (const tool of tools) { + console.log(`${tool.name}: ${tool.description}\n`); +} +``` +``` +sql_db_list_tables: Input is an empty string, output is a comma-separated list of tables in the database. 
+ +sql_db_schema: Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3 + +sql_db_query: Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. +``` +::: ## 4. Define application steps @@ -137,6 +274,7 @@ We construct dedicated nodes for the following steps: Putting these steps in dedicated nodes lets us (1) force tool-calls when needed, and (2) customize the prompts associated with each step. +:::python ```python from typing import Literal @@ -245,11 +383,117 @@ def check_query(state: MessagesState): return {"messages": [response]} ``` +::: +:::js +```typescript +import { AIMessage, ToolMessage, SystemMessage, HumanMessage } from "@langchain/core/messages"; +import { ToolNode } from "@langchain/langgraph/prebuilt"; +import { MessagesAnnotation, StateGraph, START, END } from "@langchain/langgraph"; + +// Create tool nodes for schema and query execution +const getSchemaNode = new ToolNode([getSchemaTool]); +const runQueryNode = new ToolNode([queryTool]); + +// Example: create a predetermined tool call +async function listTables(state: typeof MessagesAnnotation.State) { + const toolCall = { + name: "sql_db_list_tables", + args: {}, + id: "abc123", + type: "tool_call" as const, + }; + const toolCallMessage = new AIMessage({ + content: "", + tool_calls: [toolCall], + }); + + const toolMessage = await listTablesTool.invoke({}); + const response = new AIMessage(`Available tables: ${toolMessage}`); + + return { messages: [toolCallMessage, new ToolMessage({ content: toolMessage, tool_call_id: "abc123" }), response] }; +} + +// Example: force a model to create a tool call +async function callGetSchema(state: typeof 
MessagesAnnotation.State) { + const llmWithTools = llm.bindTools([getSchemaTool], { + tool_choice: "any", + }); + const response = await llmWithTools.invoke(state.messages); + + return { messages: [response] }; +} + +const topK = 5; + +const generateQuerySystemPrompt = ` +You are an agent designed to interact with a SQL database. +Given an input question, create a syntactically correct ${dialect} +query to run, then look at the results of the query and return the answer. Unless +the user specifies a specific number of examples they wish to obtain, always limit +your query to at most ${topK} results. + +You can order the results by a relevant column to return the most interesting +examples in the database. Never query for all the columns from a specific table, +only ask for the relevant columns given the question. + +DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. +`; + +async function generateQuery(state: typeof MessagesAnnotation.State) { + const systemMessage = new SystemMessage(generateQuerySystemPrompt); + // We do not force a tool call here, to allow the model to + // respond naturally when it obtains the solution. + const llmWithTools = llm.bindTools([queryTool]); + const response = await llmWithTools.invoke([systemMessage, ...state.messages]); + + return { messages: [response] }; +} + +const checkQuerySystemPrompt = ` +You are a SQL expert with a strong attention to detail. +Double check the ${dialect} query for common mistakes, including: +- Using NOT IN with NULL values +- Using UNION when UNION ALL should have been used +- Using BETWEEN for exclusive ranges +- Data type mismatch in predicates +- Properly quoting identifiers +- Using the correct number of arguments for functions +- Casting to the correct data type +- Using the proper columns for joins + +If there are any of the above mistakes, rewrite the query. If there are no mistakes, +just reproduce the original query. 
+ +You will call the appropriate tool to execute the query after running this check. +`; + +async function checkQuery(state: typeof MessagesAnnotation.State) { + const systemMessage = new SystemMessage(checkQuerySystemPrompt); + + // Generate an artificial user message to check + const lastMessage = state.messages[state.messages.length - 1]; + if (!lastMessage.tool_calls || lastMessage.tool_calls.length === 0) { + throw new Error("No tool calls found in the last message"); + } + const toolCall = lastMessage.tool_calls[0]; + const userMessage = new HumanMessage(toolCall.args.query); + const llmWithTools = llm.bindTools([queryTool], { + tool_choice: "any", + }); + const response = await llmWithTools.invoke([systemMessage, userMessage]); + // Preserve the original message ID + response.id = lastMessage.id; + + return { messages: [response] }; +} +``` +::: ## 5. Implement the agent We can now assemble these steps into a workflow using the [Graph API](/oss/langgraph/graph-api). We define a [conditional edge](/oss/langgraph/graph-api#conditional-edges) at the query generation step that will route to the query checker if a query is generated, or end if there are no tool calls present, such that the LLM has delivered a response to the query. 
+:::python ```python def should_continue(state: MessagesState) -> Literal[END, "check_query"]: messages = state["messages"] @@ -288,6 +532,47 @@ from langchain_core.runnables.graph import CurveStyle, MermaidDrawMethod, NodeSt display(Image(agent.get_graph().draw_mermaid_png())) ``` +::: +:::js +```typescript +function shouldContinue(state: typeof MessagesAnnotation.State): "check_query" | typeof END { + const messages = state.messages; + const lastMessage = messages[messages.length - 1]; + if (!lastMessage.tool_calls || lastMessage.tool_calls.length === 0) { + return END; + } else { + return "check_query"; + } +} + +const builder = new StateGraph(MessagesAnnotation) + .addNode("list_tables", listTables) + .addNode("call_get_schema", callGetSchema) + .addNode("get_schema", getSchemaNode) + .addNode("generate_query", generateQuery) + .addNode("check_query", checkQuery) + .addNode("run_query", runQueryNode) + .addEdge(START, "list_tables") + .addEdge("list_tables", "call_get_schema") + .addEdge("call_get_schema", "get_schema") + .addEdge("get_schema", "generate_query") + .addConditionalEdges("generate_query", shouldContinue) + .addEdge("check_query", "run_query") + .addEdge("run_query", "generate_query"); + +const agent = builder.compile(); +``` +We visualize the application below: +```typescript +import * as fs from "node:fs/promises"; + +const drawableGraph = await agent.getGraphAsync(); +const image = await drawableGraph.drawMermaidPng(); +const imageBuffer = new Uint8Array(await image.arrayBuffer()); + +await fs.writeFile("graph.png", imageBuffer); +``` +::: SQL agent graph We can now invoke the graph: +:::python ```python question = "Which genre on average has the longest tracks?" 
@@ -304,6 +590,24 @@ for step in agent.stream( ): step["messages"][-1].pretty_print() ``` +::: +:::js +```typescript +const question = "Which genre on average has the longest tracks?"; + +const stream = await agent.stream( + { messages: [{ role: "user", content: question }] }, + { streamMode: "values" } +); + +for await (const step of stream) { + if (step.messages && step.messages.length > 0) { + const lastMessage = step.messages[step.messages.length - 1]; + console.log(lastMessage.toFormattedString()); + } +} +``` +::: ``` ================================ Human Message ================================= @@ -379,9 +683,16 @@ Name: sql_db_query The genre with the longest tracks on average is "Sci Fi & Fantasy," with an average track length of approximately 2,911,783 milliseconds. Other genres with relatively long tracks include "Science Fiction," "Drama," "TV Shows," and "Comedy." ``` +:::python See [LangSmith trace](https://smith.langchain.com/public/94b8c9ac-12f7-4692-8706-836a1f30f1ea/r) for the above run. +::: +:::js + +See [LangSmith trace](https://smith.langchain.com/public/a6a96896-686a-4040-b9b5-28d701453d6f/r) for the above run. + +::: ## 6. Implement human-in-the-loop review @@ -390,6 +701,8 @@ It can be prudent to check the agent's SQL queries before they are executed for Here we leverage LangGraph's [human-in-the-loop](/oss/langgraph/interrupts) features to pause the run before executing a SQL query and wait for human review. Using LangGraph's [persistence layer](/oss/langgraph/persistence), we can pause the run indefinitely (or at least as long as the persistence layer is alive). Let's wrap the `sql_db_query` tool in a node that receives human input. We can implement this using the [interrupt](/oss/langgraph/interrupts) function. Below, we allow for input to approve the tool call, edit its arguments, or provide user feedback. 
+ +:::python ```python from langchain_core.runnables import RunnableConfig from langchain.tools import tool @@ -423,11 +736,56 @@ def run_query_tool_with_interrupt(config: RunnableConfig, **tool_input): return tool_response ``` +::: +:::js +```typescript +import { RunnableConfig } from "@langchain/core/runnables"; +import { tool } from "@langchain/core/tools"; +import { interrupt } from "@langchain/langgraph"; + +const queryToolWithInterrupt = tool( + async (input, config: RunnableConfig) => { + const request = { + action: queryTool.name, + args: input, + description: "Please review the tool call", + }; + const response = interrupt([request]); // [!code highlight] + // approve the tool call + if (response.type === "accept") { + const toolResponse = await queryTool.invoke(input, config); + return toolResponse; + } + // update tool call args + else if (response.type === "edit") { + const editedInput = response.args.args; + const toolResponse = await queryTool.invoke(editedInput, config); + return toolResponse; + } + // respond to the LLM with user feedback + else if (response.type === "response") { + const userFeedback = response.args; + return userFeedback; + } else { + throw new Error(`Unsupported interrupt response type: ${response.type}`); + } + }, + { + name: queryTool.name, + description: queryTool.description, + schema: queryTool.schema, + } +); +``` +::: + The above implementation follows the [tool interrupt example](/oss/langgraph/interrupts#configuring-interrupts) in the broader [human-in-the-loop](/oss/langgraph/interrupts) guide. Refer to that guide for details and alternatives. Let's now re-assemble our graph. We will replace the programmatic check with human review. Note that we now include a [checkpointer](/oss/langgraph/persistence); this is required to pause and resume the run. 
+ +:::python ```python from langgraph.checkpoint.memory import InMemorySaver @@ -459,7 +817,44 @@ builder.add_edge("run_query", "generate_query") checkpointer = InMemorySaver() # [!code highlight] agent = builder.compile(checkpointer=checkpointer) # [!code highlight] ``` +::: +:::js +```typescript +import { MemorySaver } from "@langchain/langgraph"; + +function shouldContinueWithHuman(state: typeof MessagesAnnotation.State): "run_query" | typeof END { + const messages = state.messages; + const lastMessage = messages[messages.length - 1]; + if (!lastMessage.tool_calls || lastMessage.tool_calls.length === 0) { + return END; + } else { + return "run_query"; + } +} + +const runQueryNodeWithInterrupt = new ToolNode([queryToolWithInterrupt]); + +const builderWithHuman = new StateGraph(MessagesAnnotation) + .addNode("list_tables", listTables) + .addNode("call_get_schema", callGetSchema) + .addNode("get_schema", getSchemaNode) + .addNode("generate_query", generateQuery) + .addNode("run_query", runQueryNodeWithInterrupt) + .addEdge(START, "list_tables") + .addEdge("list_tables", "call_get_schema") + .addEdge("call_get_schema", "get_schema") + .addEdge("get_schema", "generate_query") + .addConditionalEdges("generate_query", shouldContinueWithHuman) + .addEdge("run_query", "generate_query"); + +const checkpointer = new MemorySaver(); // [!code highlight] +const agentWithHuman = builderWithHuman.compile({ checkpointer }); // [!code highlight] +``` +::: + We can invoke the graph as before. 
This time, execution is interrupted: + +:::python ```python import json @@ -482,6 +877,34 @@ for step in agent.stream( else: pass ``` +::: +:::js +```typescript +const config = { configurable: { thread_id: "1" } }; + +const question = "Which genre on average has the longest tracks?"; + +const stream = await agentWithHuman.stream( + { messages: [{ role: "user", content: question }] }, + { ...config, streamMode: "values" } +); + +for await (const step of stream) { + if (step.messages && step.messages.length > 0) { + const lastMessage = step.messages[step.messages.length - 1]; + console.log(lastMessage.toFormattedString()); + } +} + +// Check for interrupts +const state = await agentWithHuman.getState(config); +if (state.next.length > 0) { + console.log("\nINTERRUPTED:"); + console.log(JSON.stringify(state.tasks[0].interrupts[0], null, 2)); +} +``` +::: + ``` ... @@ -495,6 +918,8 @@ INTERRUPTED: } ``` We can accept or edit the tool call using [Command](/oss/langgraph/use-graph-api#combine-control-flow-and-state-updates-with-command): + +:::python ```python from langgraph.types import Command @@ -515,6 +940,26 @@ for step in agent.stream( else: pass ``` +::: +:::js +```typescript +import { Command } from "@langchain/langgraph"; + +const resumeStream = await agentWithHuman.stream( + new Command({ resume: { type: "accept" } }), + // new Command({ resume: { type: "edit", args: { query: "..." } } }), + { ...config, streamMode: "values" } +); + +for await (const step of resumeStream) { + if (step.messages && step.messages.length > 0) { + const lastMessage = step.messages[step.messages.length - 1]; + console.log(lastMessage.toFormattedString()); + } +} +``` +::: + ``` ================================== Ai Message ================================== Tool Calls: @@ -535,9 +980,3 @@ Refer to the [human-in-the-loop guide](/oss/langgraph/interrupts) for details. 
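The resume payloads passed to `Command` above follow the `{ type, args }` convention this tutorial adopts; it is not a fixed LangGraph schema. As a minimal, self-contained sketch of how the interrupt handler dispatches on them (the type and function names here are illustrative, not part of the LangGraph API):

```typescript
// Illustrative sketch of the three resume payload shapes used in this tutorial.
type ResumePayload =
  | { type: "accept" }
  | { type: "edit"; args: { args: { query: string } } }
  | { type: "response"; args: string };

// Mirrors the branching inside queryToolWithInterrupt: given a resume payload,
// decide what the wrapped tool should do next.
function describeResume(payload: ResumePayload): string {
  switch (payload.type) {
    case "accept":
      return "run original query";
    case "edit":
      return `run edited query: ${payload.args.args.query}`;
    case "response":
      return `return feedback to the LLM: ${payload.args}`;
  }
}
```

Because the payload is whatever you pass as `resume`, you can extend this convention (for example, adding a reason string to rejections) as long as the handler and the resuming `Command` agree on the shape.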
## Next steps Check out the [Evaluate a graph](/langsmith/evaluate-graph) guide for evaluating LangGraph applications, including SQL agents like this one, using LangSmith. -::: -:::js -## Under construction - -This tutorial has not yet been implemented in Typescript. Refer to the LangChain [SQL agent guide](/oss/langchain/sql-agent) for a reference implementation. -:::