# Spark SQL Flexible Agent: going beyond SQL queries
The classical ZeroShotAgent always forces the model to use a tool. That makes sense when a tool can do the job better than the model, but in many cases the model does better than any tool, especially when the task calls for generation.

For example, when working with Spark, we sometimes want the model to suggest how to write a query that helps our work rather than run the query for us. (Letting the AI run the query is a valid use case, but it is not the only one.) The strength of an LLM is generation: it can produce more creative and complex answers, including suggestions for a human to act on.

In this notebook, we show how to unlock the model's creativity within the classical MRKL framework by adding a freestyle tool.
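
The notebook does not show how `enable_freestyle` works internally. As a rough illustration of the idea only, the pure-Python sketch below treats the model as one more entry in a MRKL-style tool registry, so the agent can route a request to the model instead of to a database tool. The tool name `list_tables_sql_db` comes from the traces in this notebook; `freestyle`, `make_freestyle_tool`, `dispatch`, and the stub callables are hypothetical and are not LangChain APIs.

```python
from typing import Callable, Dict

# A tool maps an "Action Input" string to an "Observation" string.
Tool = Callable[[str], str]


def make_freestyle_tool(llm: Callable[[str], str]) -> Tool:
    """Hypothetical: wrap a text-in/text-out model so the agent can pick it like any tool."""
    def freestyle(request: str) -> str:
        # No database access here: the model answers on its own,
        # e.g. by suggesting a query instead of running one.
        return llm(request)
    return freestyle


def dispatch(tools: Dict[str, Tool], action: str, action_input: str) -> str:
    """Route one agent step to the named tool, mimicking the Action/Observation loop."""
    tool = tools.get(action)
    if tool is None:
        return f"{action} is not a valid tool, try one of {sorted(tools)}."
    return tool(action_input)


if __name__ == "__main__":
    # Stubs standing in for a real model and a real Spark session (assumptions for the demo).
    def fake_llm(prompt: str) -> str:
        return f"Suggestion for: {prompt}"

    tools: Dict[str, Tool] = {
        "list_tables_sql_db": lambda _: "titanic",
        "freestyle": make_freestyle_tool(fake_llm),
    }
    print(dispatch(tools, "list_tables_sql_db", ""))
    print(dispatch(tools, "freestyle", "two ways to query the most expensive class"))
```

The design choice being illustrated is that "answer directly" becomes just another action, so the agent loop itself does not need to change.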

In [1]:
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
schema = "langchain_example"
spark.sql(f"CREATE DATABASE IF NOT EXISTS {schema}")
spark.sql(f"USE {schema}")
csv_file_path = "../toolkits/examples/titanic.csv"
table = "titanic"
spark.read.csv(csv_file_path, header=True, inferSchema=True).write.mode('overwrite').saveAsTable(table)
spark.table(table).show()



+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+
|PassengerId|Survived|Pclass|                Name|   Sex| Age|SibSp|Parch|          Ticket|   Fare|Cabin|Embarked|
+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+
|          1|       0|     3|Braund, Mr. Owen ...|  male|22.0|    1|    0|       A/5 21171|   7.25| null|       S|
|          2|       1|     1|Cumings, Mrs. Joh...|female|38.0|    1|    0|        PC 17599|71.2833|  C85|       C|
|          3|       1|     3|Heikkinen, Miss. ...|female|26.0|    0|    0|STON/O2. 3101282|  7.925| null|       S|
|          4|       1|     1|Futrelle, Mrs. Ja...|female|35.0|    1|    0|          113803|   53.1| C123|       S|
|          5|       0|     3|Allen, Mr. Willia...|  male|35.0|    0|    0|          373450|   8.05| null|       S|
|          6|       0|     3|    Moran, Mr. James|  male|null|    0|    0|      

In [4]:
from langchain.agents import create_spark_sql_agent
from langchain.chat_models import ChatOpenAI
from langchain.spark_sql import SparkSQL

spark_sql = SparkSQL(schema=schema)
llm = ChatOpenAI(temperature=0, model_name="gpt-4")
agent_executor = create_spark_sql_agent(
    db=spark_sql,
    llm=llm,
    enable_memory=True,
    enable_freestyle=True,
    verbose=True,
)

In [5]:
agent_executor.run("give me two different ways to query the most expensive class")



> Entering new AgentExecutor chain...
Action: list_tables_sql_db
Action Input:
Observation: titanic
Thought: I need to see the schema of the titanic table to understand the columns and data.
Action: schema_sql_db
Action Input: titanic[0m
Observation: [33;1m[1;3mCREATE TABLE langchain_example.titanic (
  PassengerId INT,
  Survived INT,
  Pclass INT,
  Name STRING,
  Sex STRING,
  Age DOUBLE,
  SibSp INT,
  Parch INT,
  Ticket STRING,
  Fare DOUBLE,
  Cabin STRING,
  Embarked STRING)
;

/*
3 rows from titanic table:
PassengerId	Survived	Pclass	Name	Sex	Age	SibSp	Parch	Ticket	Fare	Cabin	Embarked
1	0	3	Braund, Mr. Owen Harris	male	22.0	1	0	A/5 21171	7.25	None	S
2	1	1	Cumings, Mrs. John Bradley (Florence Briggs Thayer)	female	38.0	1	0	PC 17599	71.2833	C85	C
3	1	3	Heikkinen, Miss. Laina	female	26.0	0	0	STON/O2. 3101282	7.925	None	S
*/
Thought: I need to find the most expensive class, which is related to the Fare

'Here are two different ways to query the most expensive class:\n\n1. SELECT Pclass, MAX(Fare) as Max_Fare FROM titanic GROUP BY Pclass ORDER BY Max_Fare DESC LIMIT 1;\n\n2. SELECT Pclass, Fare FROM titanic WHERE Fare = (SELECT MAX(Fare) FROM titanic) LIMIT 1;'
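
Both suggestions read "most expensive class" as "the class containing the single highest fare". A quick way to check that the two queries agree is to run them side by side. The snippet below swaps Spark for Python's built-in `sqlite3` and a handful of made-up rows so it runs without a cluster; treat it as an illustration of the SQL semantics, not as a result from the Titanic table.

```python
import sqlite3

# Hypothetical sample rows (Pclass, Fare); not taken from titanic.csv.
ROWS = [(1, 512.33), (1, 263.0), (1, 30.0), (2, 73.5), (2, 13.0), (3, 69.55), (3, 7.25)]

# Suggestion 1: highest per-class maximum fare.
QUERY_1 = """
SELECT Pclass, MAX(Fare) AS Max_Fare FROM titanic
GROUP BY Pclass ORDER BY Max_Fare DESC LIMIT 1
"""

# Suggestion 2: the row whose fare equals the overall maximum.
QUERY_2 = """
SELECT Pclass, Fare FROM titanic
WHERE Fare = (SELECT MAX(Fare) FROM titanic) LIMIT 1
"""


def run(sql: str) -> tuple:
    # In-memory database so the example leaves nothing behind.
    conn = sqlite3.connect(":memory:")
    try:
        conn.execute("CREATE TABLE titanic (Pclass INTEGER, Fare REAL)")
        conn.executemany("INSERT INTO titanic VALUES (?, ?)", ROWS)
        return conn.execute(sql).fetchone()
    finally:
        conn.close()


print(run(QUERY_1))  # (1, 512.33)
print(run(QUERY_2))  # (1, 512.33)
```

If "most expensive" should instead mean the highest average fare, replace `MAX` with `AVG` in the first query; the two readings can disagree on other data.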

In [6]:
agent_executor.run("are they accurate?")



> Entering new AgentExecutor chain...
Action: list_tables_sql_db
Action Input:
Observation: titanic
Thought: I should check the schema of the titanic table to see if the columns I used in the queries are correct.
Action: schema_sql_db
Action Input: titanic[0m
Observation: [33;1m[1;3mCREATE TABLE langchain_example.titanic (
  PassengerId INT,
  Survived INT,
  Pclass INT,
  Name STRING,
  Sex STRING,
  Age DOUBLE,
  SibSp INT,
  Parch INT,
  Ticket STRING,
  Fare DOUBLE,
  Cabin STRING,
  Embarked STRING)
;

/*
3 rows from titanic table:
PassengerId	Survived	Pclass	Name	Sex	Age	SibSp	Parch	Ticket	Fare	Cabin	Embarked
1	0	3	Braund, Mr. Owen Harris	male	22.0	1	0	A/5 21171	7.25	None	S
2	1	1	Cumings, Mrs. John Bradley (Florence Briggs Thayer)	female	38.0	1	0	PC 17599	71.2833	C85	C
3	1	3	Heikkinen, Miss. Laina	female	26.0	0	0	STON/O2. 3101282	7.925	None	S
*/
Thought: The schema of the titanic table matches the col

'Yes, both queries are accurate.'
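
The follow-up "are they accurate?" only makes sense because `enable_memory=True` lets the agent see earlier turns. The notebook does not show how that memory is stored; a minimal buffer-style sketch, with the hypothetical `ChatBuffer` class standing in for whatever the agent actually uses, could look like this:

```python
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class ChatBuffer:
    """Hypothetical buffer memory: keeps (question, answer) pairs and replays them."""
    turns: List[Tuple[str, str]] = field(default_factory=list)

    def add(self, question: str, answer: str) -> None:
        self.turns.append((question, answer))

    def build_prompt(self, question: str) -> str:
        # Earlier turns come first so a pronoun like "they" can be resolved.
        history = "\n".join(f"Human: {q}\nAI: {a}" for q, a in self.turns)
        return f"{history}\nHuman: {question}" if history else f"Human: {question}"


memory = ChatBuffer()
memory.add(
    "give me two different ways to query the most expensive class",
    "Two queries: one with MAX and GROUP BY, one with a MAX subquery.",
)
print(memory.build_prompt("are they accurate?"))
```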

In [7]:
agent_executor.run("cool, then let's use the second solution and query the second most expensive")



> Entering new AgentExecutor chain...
Action: list_tables_sql_db
Action Input:
Observation: titanic
Thought: I know the table I need to query is the titanic table. I will modify the second query to find the second most expensive class.
Action: query_checker_sql_db
Action Input: SELECT Pclass, Fare FROM titanic WHERE Fare = (SELECT DISTINCT Fare FROM titanic ORDER BY Fare DESC LIMIT 1 OFFSET 1) LIMIT 1;
Observation: The original query seems to be correct. It is trying to find the second highest fare and its corresponding Pclass from the titanic dataset. Here's the query again:

SELECT Pclass, Fare FROM titanic WHERE Fare = (SELECT DISTINCT Fare FROM titanic ORDER BY Fare DESC LIMIT 1 OFFSET 1) LIMIT 1;
Thought: The query is correct, so I will execute it to find the second most expensive class.
Action: query_sql_db
Action Input: SELECT Pclass, Fare FROM titanic WHERE Fare = (SELECT DISTINCT Fare

'The second most expensive class is class 1 with a fare of 263.0.'
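
Note what this query actually returns: the class of the row with the second-highest distinct fare, which need not be a different class from the one holding the top fare. The sketch below contrasts it with a query that ranks classes by their maximum fare and skips the top class. It uses Python's built-in `sqlite3` and made-up rows as a stand-in for Spark, so the numbers only illustrate the SQL semantics.

```python
import sqlite3

# Hypothetical sample rows (Pclass, Fare); not taken from titanic.csv.
ROWS = [(1, 512.33), (1, 263.0), (1, 30.0), (2, 73.5), (2, 13.0), (3, 69.55), (3, 7.25)]

# The agent's query: the row whose fare is the second-highest distinct fare.
SECOND_FARE = """
SELECT Pclass, Fare FROM titanic
WHERE Fare = (SELECT DISTINCT Fare FROM titanic ORDER BY Fare DESC LIMIT 1 OFFSET 1)
LIMIT 1
"""

# Alternative reading: rank classes by their top fare and skip the first class.
SECOND_CLASS_BY_MAX = """
SELECT Pclass, MAX(Fare) AS Max_Fare FROM titanic
GROUP BY Pclass ORDER BY Max_Fare DESC LIMIT 1 OFFSET 1
"""


def first_row(sql: str) -> tuple:
    # In-memory database so the example leaves nothing behind.
    conn = sqlite3.connect(":memory:")
    try:
        conn.execute("CREATE TABLE titanic (Pclass INTEGER, Fare REAL)")
        conn.executemany("INSERT INTO titanic VALUES (?, ?)", ROWS)
        return conn.execute(sql).fetchone()
    finally:
        conn.close()


print(first_row(SECOND_FARE))          # (1, 263.0)
print(first_row(SECOND_CLASS_BY_MAX))  # (2, 73.5)
```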

# Reading Logical and Physical Plans
In the example below, we show how the AI uses Spark's logical and physical plans to further explain an executed query.
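
Independently of the agent, Spark exposes these plans directly: the SQL statement `EXPLAIN EXTENDED <query>` (the SQL counterpart of `DataFrame.explain(extended=True)`) returns the parsed, analyzed, and optimized logical plans plus the physical plan. How the agent's own tool fetches them is not shown in this notebook, so the helper below is only a sketch; `build_explain` and `get_plan` are hypothetical names, and `spark` is assumed to be the `SparkSession` created in the setup cell.

```python
# Modes of Spark SQL's EXPLAIN statement (Spark 3.0+); "SIMPLE" here means plain EXPLAIN.
EXPLAIN_MODES = {"SIMPLE", "EXTENDED", "CODEGEN", "COST", "FORMATTED"}


def build_explain(query: str, mode: str = "EXTENDED") -> str:
    """Hypothetical helper: wrap a single query in an EXPLAIN statement."""
    mode = mode.upper()
    if mode not in EXPLAIN_MODES:
        raise ValueError(f"unsupported EXPLAIN mode: {mode}")
    # Drop surrounding whitespace and a trailing semicolon before wrapping.
    query = query.strip().rstrip(";")
    return f"EXPLAIN {query}" if mode == "SIMPLE" else f"EXPLAIN {mode} {query}"


def get_plan(spark, query: str, mode: str = "EXTENDED") -> str:
    """Run EXPLAIN through a SparkSession; the result is one row with a single 'plan' column."""
    return spark.sql(build_explain(query, mode)).collect()[0][0]


# Usage (requires the SparkSession from the setup cell):
# print(get_plan(spark, "SELECT COUNT(*) FROM titanic WHERE Pclass = 2"))
```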

In [2]:
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
schema = "langchain_example"
spark.sql(f"CREATE DATABASE IF NOT EXISTS {schema}")
spark.sql(f"USE {schema}")
csv_file_path = "../toolkits/examples/titanic.csv"
table = "titanic"
spark.read.csv(csv_file_path, header=True, inferSchema=True).write.mode('overwrite').saveAsTable(table)
spark.table(table).show()

+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+
|PassengerId|Survived|Pclass|                Name|   Sex| Age|SibSp|Parch|          Ticket|   Fare|Cabin|Embarked|
+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+
|          1|       0|     3|Braund, Mr. Owen ...|  male|22.0|    1|    0|       A/5 21171|   7.25| null|       S|
|          2|       1|     1|Cumings, Mrs. Joh...|female|38.0|    1|    0|        PC 17599|71.2833|  C85|       C|
|          3|       1|     3|Heikkinen, Miss. ...|female|26.0|    0|    0|STON/O2. 3101282|  7.925| null|       S|
|          4|       1|     1|Futrelle, Mrs. Ja...|female|35.0|    1|    0|          113803|   53.1| C123|       S|
|          5|       0|     3|Allen, Mr. Willia...|  male|35.0|    0|    0|          373450|   8.05| null|       S|
|          6|       0|     3|    Moran, Mr. James|  male|null|    0|    0|      

In [5]:
from langchain.agents import create_spark_sql_agent
from langchain.chat_models import ChatOpenAI
from langchain.spark_sql import SparkSQL

spark_sql = SparkSQL(schema=schema)
llm = ChatOpenAI(temperature=0, model_name="gpt-4")
agent_executor = create_spark_sql_agent(
    db=spark_sql,
    llm=llm,
    enable_memory=True,
    enable_freestyle=True,
    verbose=True,
)

In [8]:
agent_executor.run("how many tickets sold in the second most expensive class?")



> Entering new AgentExecutor chain...
Action: list_tables_sql_db
Action Input: ""
Observation: titanic
Thought: I should check the schema of the titanic table to understand the columns and their data types.
Action: schema_sql_db
Action Input: "titanic"
Observation: CREATE TABLE langchain_example.titanic (
  PassengerId INT,
  Survived INT,
  Pclass INT,
  Name STRING,
  Sex STRING,
  Age DOUBLE,
  SibSp INT,
  Parch INT,
  Ticket STRING,
  Fare DOUBLE,
  Cabin STRING,
  Embarked STRING)
;

/*
3 rows from titanic table:
PassengerId	Survived	Pclass	Name	Sex	Age	SibSp	Parch	Ticket	Fare	Cabin	Embarked
1	0	3	Braund, Mr. Owen Harris	male	22.0	1	0	A/5 21171	7.25	None	S
2	1	1	Cumings, Mrs. John Bradley (Florence Briggs Thayer)	female	38.0	1	0	PC 17599	71.2833	C85	C
3	1	3	Heikkinen, Miss. Laina	female	26.0	0	0	STON/O2. 3101282	7.925	None	S
*/
Thought: Now that I know the schema of the titanic table, I c

'There were 184 tickets sold in the second most expensive class.'
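
The trace above is cut off before the final query, but the next cell shows the agent checking a query that ranks classes by average fare, skips the top one, and counts the rows in the remaining class. The same logic, run against Python's built-in `sqlite3` and a few made-up rows as a stand-in for Spark, looks like this:

```python
import sqlite3

# Hypothetical sample rows (Pclass, Fare); not taken from titanic.csv.
ROWS = [(1, 512.33), (1, 263.0), (1, 30.0), (2, 73.5), (2, 13.0), (3, 69.55), (3, 7.25)]

# Rank classes by average fare, skip the top one, then count that class's rows.
COUNT_SECOND_CLASS = """
SELECT COUNT(*) FROM titanic
WHERE Pclass = (
    SELECT Pclass FROM (
        SELECT Pclass, AVG(Fare) AS avg_fare FROM titanic
        GROUP BY Pclass ORDER BY avg_fare DESC LIMIT 1 OFFSET 1
    ) AS second_most_expensive_class
)
"""

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE titanic (Pclass INTEGER, Fare REAL)")
conn.executemany("INSERT INTO titanic VALUES (?, ?)", ROWS)
count = conn.execute(COUNT_SECOND_CLASS).fetchone()[0]
conn.close()
print(count)  # 2 (class 2 has the second-highest average fare in this sample)
```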

In [9]:
agent_executor.run("how's it optimized by spark? explain the optimized logical plan to me.")



> Entering new AgentExecutor chain...
Action: list_tables_sql_db
Action Input: ""
Observation: titanic
Thought: I need to find the second most expensive class and count the number of tickets sold in that class. I will write a query to achieve this and then get the logical and physical plans.
Action: query_checker_sql_db
Action Input: SELECT COUNT(*) FROM titanic WHERE Pclass = (SELECT Pclass FROM (SELECT Pclass, AVG(Fare) as avg_fare FROM titanic GROUP BY Pclass ORDER BY avg_fare DESC LIMIT 1 OFFSET 1) as second_most_expensive_class)
Observation: The original query seems to be correct. Here it is reproduced:

SELECT COUNT(*) FROM titanic WHERE Pclass = (SELECT Pclass FROM (SELECT Pclass, AVG(Fare) as avg_fare FROM titanic GROUP BY Pclass ORDER BY avg_fare DESC LIMIT 1 OFFSET 1) as second_most_expensive_class)
Thought: Now that I have the correct query, I will get the logical and physical plans

'The optimized logical plan for the query is as follows:\n\n1. The query starts with a scalar subquery to find the second most expensive class by calculating the average fare for each class and ordering them in descending order. The subquery is then used in the main query.\n2. The main query performs a Filter operation to remove any rows with null Pclass values and to keep only the rows where Pclass equals the second most expensive class.\n3. The Aggregate operation is used to count the number of rows in the result, which represents the number of tickets sold in the second most expensive class.\n\nThe physical plan uses AdaptiveSparkPlan, HashAggregate, Exchange, and FileScan operations to execute the query efficiently.'
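
To read the optimized logical plan yourself instead of relying on the model's summary, note that the text returned by `EXPLAIN EXTENDED` (or printed by `DataFrame.explain(extended=True)`) is divided into sections headed `== Parsed Logical Plan ==`, `== Analyzed Logical Plan ==`, `== Optimized Logical Plan ==`, and `== Physical Plan ==`. The hypothetical helper below splits that text on those headers so one section can be printed or handed to a model on its own; the sample string is a placeholder, not real Spark output.

```python
import re
from typing import Dict

# Section headers look like "== Optimized Logical Plan ==" on a line of their own.
_HEADER = re.compile(r"^== (.+?) ==[ \t]*$", re.MULTILINE)


def split_plan_sections(plan_text: str) -> Dict[str, str]:
    """Hypothetical helper: map each '== Title ==' header to the text beneath it."""
    headers = list(_HEADER.finditer(plan_text))
    sections: Dict[str, str] = {}
    for i, match in enumerate(headers):
        # A section runs until the next header, or to the end of the text.
        end = headers[i + 1].start() if i + 1 < len(headers) else len(plan_text)
        sections[match.group(1)] = plan_text[match.end():end].strip()
    return sections


# In practice plan_text would come from spark.sql("EXPLAIN EXTENDED <query>").collect()[0][0];
# the sample below is a placeholder.
sample = "== Optimized Logical Plan ==\n<optimized plan>\n\n== Physical Plan ==\n<physical plan>\n"
print(split_plan_sections(sample)["Optimized Logical Plan"])  # <optimized plan>
```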