Add Spark SQL support (#4602) #4956

Merged · 3 commits · May 19, 2023
348 changes: 348 additions & 0 deletions docs/modules/agents/toolkits/examples/spark_sql.ipynb

Large diffs are not rendered by default.
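Since the notebook diff is not rendered, here is a minimal sketch of the kind of setup such an example would walk through. The schema name, CSV file, and table name are illustrative assumptions, and the sketch assumes the new `SparkSQL` utility accepts an active session via a `spark_session` argument:

```python
from pyspark.sql import SparkSession

from langchain.utilities.spark_sql import SparkSQL

# Start a local Spark session and register a small sample table.
spark = SparkSession.builder.getOrCreate()
schema = "langchain_example"  # illustrative schema name
spark.sql(f"CREATE DATABASE IF NOT EXISTS {schema}")
spark.sql(f"USE {schema}")

# Any small CSV registered as a table works; "titanic.csv" is hypothetical.
df = spark.read.csv("titanic.csv", header=True, inferSchema=True)
df.write.mode("overwrite").saveAsTable("titanic")

# Wrap the active session so the tools below can run queries against it
# (assumes the SparkSQL constructor takes spark_session and schema).
spark_sql = SparkSQL(spark_session=spark, schema=schema)
```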

2 changes: 2 additions & 0 deletions langchain/agents/__init__.py
@@ -15,6 +15,7 @@
    create_pbi_agent,
    create_pbi_chat_agent,
    create_spark_dataframe_agent,
    create_spark_sql_agent,
    create_sql_agent,
    create_vectorstore_agent,
    create_vectorstore_router_agent,
@@ -59,6 +60,7 @@
    "create_pbi_agent",
    "create_pbi_chat_agent",
    "create_spark_dataframe_agent",
    "create_spark_sql_agent",
    "create_sql_agent",
    "create_vectorstore_agent",
    "create_vectorstore_router_agent",
4 changes: 4 additions & 0 deletions langchain/agents/agent_toolkits/__init__.py
@@ -18,6 +18,8 @@
from langchain.agents.agent_toolkits.powerbi.toolkit import PowerBIToolkit
from langchain.agents.agent_toolkits.python.base import create_python_agent
from langchain.agents.agent_toolkits.spark.base import create_spark_dataframe_agent
from langchain.agents.agent_toolkits.spark_sql.base import create_spark_sql_agent
from langchain.agents.agent_toolkits.spark_sql.toolkit import SparkSQLToolkit
from langchain.agents.agent_toolkits.sql.base import create_sql_agent
from langchain.agents.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain.agents.agent_toolkits.vectorstore.base import (
@@ -41,6 +43,7 @@
    "create_vectorstore_agent",
    "JsonToolkit",
    "SQLDatabaseToolkit",
    "SparkSQLToolkit",
    "NLAToolkit",
    "PowerBIToolkit",
    "OpenAPIToolkit",
@@ -50,6 +53,7 @@
    "VectorStoreRouterToolkit",
    "create_pandas_dataframe_agent",
    "create_spark_dataframe_agent",
    "create_spark_sql_agent",
    "create_csv_agent",
    "ZapierToolkit",
    "GmailToolkit",
1 change: 1 addition & 0 deletions langchain/agents/agent_toolkits/spark_sql/__init__.py
@@ -0,0 +1 @@
"""Spark SQL agent."""
56 changes: 56 additions & 0 deletions langchain/agents/agent_toolkits/spark_sql/base.py
@@ -0,0 +1,56 @@
"""Spark SQL agent."""
from typing import Any, Dict, List, Optional

from langchain.agents.agent import AgentExecutor
from langchain.agents.agent_toolkits.spark_sql.prompt import SQL_PREFIX, SQL_SUFFIX
from langchain.agents.agent_toolkits.spark_sql.toolkit import SparkSQLToolkit
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain


def create_spark_sql_agent(
    llm: BaseLanguageModel,
    toolkit: SparkSQLToolkit,
    callback_manager: Optional[BaseCallbackManager] = None,
    prefix: str = SQL_PREFIX,
    suffix: str = SQL_SUFFIX,
    format_instructions: str = FORMAT_INSTRUCTIONS,
    input_variables: Optional[List[str]] = None,
    top_k: int = 10,
    max_iterations: Optional[int] = 15,
    max_execution_time: Optional[float] = None,
    early_stopping_method: str = "force",
    verbose: bool = False,
    agent_executor_kwargs: Optional[Dict[str, Any]] = None,
    **kwargs: Dict[str, Any],
) -> AgentExecutor:
    """Construct a Spark SQL agent from an LLM and toolkit."""
    tools = toolkit.get_tools()
    prefix = prefix.format(top_k=top_k)
    prompt = ZeroShotAgent.create_prompt(
        tools,
        prefix=prefix,
        suffix=suffix,
        format_instructions=format_instructions,
        input_variables=input_variables,
    )
    llm_chain = LLMChain(
        llm=llm,
        prompt=prompt,
        callback_manager=callback_manager,
    )
    tool_names = [tool.name for tool in tools]
    agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
    return AgentExecutor.from_agent_and_tools(
        agent=agent,
        tools=tools,
        callback_manager=callback_manager,
        verbose=verbose,
        max_iterations=max_iterations,
        max_execution_time=max_execution_time,
        early_stopping_method=early_stopping_method,
        **(agent_executor_kwargs or {}),
    )
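A minimal sketch of calling the new entry point end to end, reusing the `spark_sql` wrapper from the sketch above (the model choice and question are illustrative assumptions):

```python
from langchain.agents import create_spark_sql_agent
from langchain.agents.agent_toolkits import SparkSQLToolkit
from langchain.chat_models import ChatOpenAI

llm = ChatOpenAI(temperature=0)  # any BaseLanguageModel should work here
toolkit = SparkSQLToolkit(db=spark_sql, llm=llm)
agent_executor = create_spark_sql_agent(llm=llm, toolkit=toolkit, verbose=True)

# The agent will list tables, inspect schemas, check its query, then run it.
agent_executor.run("How many rows are in the titanic table?")
```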
21 changes: 21 additions & 0 deletions langchain/agents/agent_toolkits/spark_sql/prompt.py
@@ -0,0 +1,21 @@
# flake8: noqa

SQL_PREFIX = """You are an agent designed to interact with Spark SQL.
Given an input question, create a syntactically correct Spark SQL query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the below tools. Only use the information returned by the below tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.

DO NOT make any DML or DDL statements (INSERT, UPDATE, DELETE, DROP, etc.) against the database.

If the question does not seem related to the database, just return "I don't know" as the answer.
"""

SQL_SUFFIX = """Begin!

Question: {input}
Thought: I should look at the tables in the database to see what I can query.
{agent_scratchpad}"""
36 changes: 36 additions & 0 deletions langchain/agents/agent_toolkits/spark_sql/toolkit.py
@@ -0,0 +1,36 @@
"""Toolkit for interacting with Spark SQL."""
from typing import List

from pydantic import Field

from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.base_language import BaseLanguageModel
from langchain.tools import BaseTool
from langchain.tools.spark_sql.tool import (
InfoSparkSQLTool,
ListSparkSQLTool,
QueryCheckerTool,
QuerySparkSQLTool,
)
from langchain.utilities.spark_sql import SparkSQL


class SparkSQLToolkit(BaseToolkit):
    """Toolkit for interacting with Spark SQL."""

    db: SparkSQL = Field(exclude=True)
    llm: BaseLanguageModel = Field(exclude=True)

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    def get_tools(self) -> List[BaseTool]:
        """Get the tools in the toolkit."""
        return [
            QuerySparkSQLTool(db=self.db),
            InfoSparkSQLTool(db=self.db),
            ListSparkSQLTool(db=self.db),
            QueryCheckerTool(db=self.db, llm=self.llm),
        ]
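For reference, `get_tools()` wires up the four tools defined in `langchain/tools/spark_sql/tool.py` below; a quick sketch to confirm the registered tool names (reusing `spark_sql` and `llm` from the earlier sketches):

```python
from langchain.agents.agent_toolkits import SparkSQLToolkit

toolkit = SparkSQLToolkit(db=spark_sql, llm=llm)
print([tool.name for tool in toolkit.get_tools()])
# -> ['query_sql_db', 'schema_sql_db', 'list_tables_sql_db', 'query_checker_sql_db']
```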
1 change: 1 addition & 0 deletions langchain/tools/spark_sql/__init__.py
@@ -0,0 +1 @@
"""Tools for interacting with Spark SQL."""
14 changes: 14 additions & 0 deletions langchain/tools/spark_sql/prompt.py
@@ -0,0 +1,14 @@
# flake8: noqa
QUERY_CHECKER = """
{query}
Double check the Spark SQL query above for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins

If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query."""
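The template takes a single `query` input variable; a small sketch of rendering it standalone, mirroring the `PromptTemplate` wiring used by `QueryCheckerTool` below (the sample query is illustrative):

```python
from langchain.prompts import PromptTemplate
from langchain.tools.spark_sql.prompt import QUERY_CHECKER

prompt = PromptTemplate(template=QUERY_CHECKER, input_variables=["query"])
# Prints the checker instructions with the query substituted at the top.
print(prompt.format(query="SELECT name FROM titanic WHERE age NOT IN (NULL)"))
```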
152 changes: 152 additions & 0 deletions langchain/tools/spark_sql/tool.py
@@ -0,0 +1,152 @@
# flake8: noqa
"""Tools for interacting with Spark SQL."""
from typing import Any, Dict, Optional

from pydantic import BaseModel, Extra, Field, root_validator

from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain.chains.llm import LLMChain
from langchain.prompts import PromptTemplate
from langchain.utilities.spark_sql import SparkSQL
from langchain.tools.base import BaseTool
from langchain.tools.spark_sql.prompt import QUERY_CHECKER


class BaseSparkSQLTool(BaseModel):
    """Base tool for interacting with Spark SQL."""

    db: SparkSQL = Field(exclude=True)

    # Override BaseTool.Config to appease mypy
    # See https://github.com/pydantic/pydantic/issues/4173
    class Config(BaseTool.Config):
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True
        extra = Extra.forbid


class QuerySparkSQLTool(BaseSparkSQLTool, BaseTool):
    """Tool for querying Spark SQL."""

    name = "query_sql_db"
    description = """
    Input to this tool is a detailed and correct SQL query, output is a result from Spark SQL.
    If the query is not correct, an error message will be returned.
    If an error is returned, rewrite the query, check the query, and try again.
    """

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Execute the query, return the results or an error message."""
        return self.db.run_no_throw(query)

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        raise NotImplementedError("QuerySparkSQLTool does not support async")


class InfoSparkSQLTool(BaseSparkSQLTool, BaseTool):
    """Tool for getting metadata about Spark SQL tables."""

    name = "schema_sql_db"
    description = """
    Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables.
    Be sure that the tables actually exist by calling list_tables_sql_db first!

    Example Input: "table1, table2, table3"
    """

    def _run(
        self,
        table_names: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Get the schema for tables in a comma-separated list."""
        return self.db.get_table_info_no_throw(table_names.split(", "))

    async def _arun(
        self,
        table_names: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        raise NotImplementedError("InfoSparkSQLTool does not support async")


class ListSparkSQLTool(BaseSparkSQLTool, BaseTool):
    """Tool for getting table names."""

    name = "list_tables_sql_db"
    description = "Input is an empty string, output is a comma-separated list of tables in Spark SQL."

    def _run(
        self,
        tool_input: str = "",
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Return a comma-separated list of usable table names."""
        return ", ".join(self.db.get_usable_table_names())

    async def _arun(
        self,
        tool_input: str = "",
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        raise NotImplementedError("ListSparkSQLTool does not support async")


class QueryCheckerTool(BaseSparkSQLTool, BaseTool):
    """Use an LLM to check if a query is correct.

    Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/
    """

    template: str = QUERY_CHECKER
    llm: BaseLanguageModel
    llm_chain: LLMChain = Field(init=False)
    name = "query_checker_sql_db"
    description = """
    Use this tool to double check if your query is correct before executing it.
    Always use this tool before executing a query with query_sql_db!
    """

    @root_validator(pre=True)
    def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        if "llm_chain" not in values:
            values["llm_chain"] = LLMChain(
                llm=values.get("llm"),
                prompt=PromptTemplate(
                    template=QUERY_CHECKER, input_variables=["query"]
                ),
            )

        if values["llm_chain"].prompt.input_variables != ["query"]:
            raise ValueError(
                "LLM chain for QueryCheckerTool needs to use ['query'] as "
                "input_variables for the embedded prompt"
            )

        return values

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Use the LLM to check the query."""
        return self.llm_chain.predict(query=query)

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        return await self.llm_chain.apredict(query=query)
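Each tool can also be exercised directly, outside the agent loop; a sketch reusing `spark_sql` from the earlier setup (the table name is illustrative):

```python
from langchain.tools.spark_sql.tool import (
    InfoSparkSQLTool,
    ListSparkSQLTool,
    QuerySparkSQLTool,
)

print(ListSparkSQLTool(db=spark_sql).run(""))          # e.g. "titanic"
print(InfoSparkSQLTool(db=spark_sql).run("titanic"))   # schema plus sample rows
print(QuerySparkSQLTool(db=spark_sql).run("SELECT COUNT(*) FROM titanic"))
```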
2 changes: 2 additions & 0 deletions langchain/utilities/__init__.py
@@ -16,6 +16,7 @@
from langchain.utilities.python import PythonREPL
from langchain.utilities.searx_search import SearxSearchWrapper
from langchain.utilities.serpapi import SerpAPIWrapper
from langchain.utilities.spark_sql import SparkSQL
from langchain.utilities.wikipedia import WikipediaAPIWrapper
from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper

@@ -38,5 +39,6 @@
    "PythonREPL",
    "LambdaWrapper",
    "PowerBIDataset",
    "SparkSQL",
    "MetaphorSearchAPIWrapper",
]