In [2]:
from pyspark.sql import SparkSession

spark = SparkSession.Builder()\
        .master("local[*]")\
            .appName("GenAI with SQL")\
            .getOrCreate()

In [6]:
from pyspark.sql.types import StructType, StructField
from pyspark.sql.types import DoubleType, IntegerType, StringType

schema = StructType([
    StructField("product_id", IntegerType()),
    StructField("product_name", StringType())
])

df = spark.read.format("csv")\
.schema(schema=schema)\
.load("input_data/product.csv")

df.show()

+----------+------------+
|product_id|product_name|
+----------+------------+
|       101|   toothpast|
|       102|       glass|
|       103|    horlicks|
|       104|       sugar|
|       105|     jaggery|
|       106|     cookies|
|       107|      candle|
|       108|       paper|
|       109|         pen|
|       110|        copy|
|       111|      pencil|
|       112|      snacks|
+----------+------------+



In [1]:
from langchain_community.utilities import SparkSQL ,SQLDatabase


mysql_URL = "mysql+mysqlconnector://root:root@localhost:3306/mydatabase"
print(f" mysql URL {mysql_URL} ")
try:
    SQLDatabase.from_uri(mysql_URL)
    #SparkSQL.from_uri(mysql_URL)
    print("Database connection successful!")
except RuntimeError as e:
    print(f"Error occurred: {e}")

 mysql URL mysql+mysqlconnector://root:root@localhost:3306/mydatabase 


: 

In [None]:
from pyspark.sql import SparkSession
spark = SparkSession.builder \
    .appName("MySQL Connection") \
    .config("spark.jars", "jars/mysql-connector-9.1.0.jar") \
    .getOrCreate()

# Then use spark.read for the JDBC connection
df = spark.read \
    .format("jdbc") \
    .option("url", "jdbc:mysql://localhost:3306/mydatabase") \
    .option("driver","com.mysql.cj.jdbc.Driver")\
    .option("dbtable", "athlete") \
    .option("user", "root") \
    .option("password", "root") \
    .load()

df.write.saveAsTable("athlete_table")


In [14]:
#get environment variables 
import os
GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY')
DATABASE = os.getenv("DATABASE")

In [15]:
from langchain_community.agent_toolkits import SparkSQLToolkit, create_spark_sql_agent
from langchain_community.utilities.spark_sql import SparkSQL
#from langchain_openai import ChatOpenAI
import google.generativeai as genai
from langchain_google_genai import ChatGoogleGenerativeAI



spark_sql = SparkSQL(schema=DATABASE)
llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash",api_key=GOOGLE_API_KEY) #ChatOpenAI(temperature=0)
toolkit = SparkSQLToolkit(db=spark_sql, llm=llm)
agent_executor = create_spark_sql_agent(llm=llm, toolkit=toolkit, verbose=True)

In [18]:
res = agent_executor.run("Describe the athelete table")




[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mAction: list_tables_sql_db
Action Input: [0m
Observation: [38;5;200m[1;3mathlete_table[0m
Thought:[32;1m[1;3mThought: I should see what the schema of the athlete_table is.
Action: schema_sql_db
Action Input: athlete_table[0m
Observation: [33;1m[1;3mCREATE TABLE `default`.`athlete_table` (
  `ID` INT,
  `Name` STRING,
  `Sex` STRING,
  `Age` STRING,
  `Height` STRING,
  `Weight` STRING,
  `Team` STRING,
  `NOC` STRING,
  `Games` STRING,
  `Year` STRING,
  `Season` STRING,
  `City` STRING,
  `Sport` STRING,
  `Event` STRING,
  `Medal` STRING)
;

/*
3 rows from athlete_table table:
ID	Name	Sex	Age	Height	Weight	Team	NOC	Games	Year	Season	City	Sport	Event	Medal
75056	Marco Marin	M	25	180	71	Italy	ITA	1988 Summer	1988	Summer	Seoul	Fencing	Fencing Men's Sabre, Individual	NA
1	A Dijiang	M	24	180	80	China	CHN	1992 Summer	1992	Summer	Barcelona	Basketball	Basketball Men's Basketball	NA
2	A Lamusi	M	23	170	60	China	CHN	2012 Sum

In [19]:
agent_executor.run("who is oldest athelete")



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mAction: list_tables_sql_db
Action Input: [0m
Observation: [38;5;200m[1;3mathlete_table[0m
Thought:[32;1m[1;3mThought: I need to look at the schema of the athlete_table to see if it has an age column.
Action: schema_sql_db
Action Input: athlete_table[0m
Observation: [33;1m[1;3mCREATE TABLE `default`.`athlete_table` (
  `ID` INT,
  `Name` STRING,
  `Sex` STRING,
  `Age` STRING,
  `Height` STRING,
  `Weight` STRING,
  `Team` STRING,
  `NOC` STRING,
  `Games` STRING,
  `Year` STRING,
  `Season` STRING,
  `City` STRING,
  `Sport` STRING,
  `Event` STRING,
  `Medal` STRING)
;

/*
3 rows from athlete_table table:
ID	Name	Sex	Age	Height	Weight	Team	NOC	Games	Year	Season	City	Sport	Event	Medal
75056	Marco Marin	M	25	180	71	Italy	ITA	1988 Summer	1988	Summer	Seoul	Fencing	Fencing Men's Sabre, Individual	NA
1	A Dijiang	M	24	180	80	China	CHN	1992 Summer	1992	Summer	Barcelona	Basketball	Basketball Men's Basketball	NA
2	A Lamusi	M	

'Federico Antonio Mariscal Abascal, Alonso Mariscal Abascal, Lucien Marinier, Mohamed Jamshid Abadi, Kyros Marinis, Georgi Abadzhiev, Ilias Marinos, Sofia Marinou, Sofia Marinou, Georgi Abadzhiev are the oldest athletes.'

In [21]:
response = agent_executor.run("what is the name of youngest athelete")



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mAction: list_tables_sql_db
Action Input: [0m
Observation: [38;5;200m[1;3mathlete_table[0m
Thought:[32;1m[1;3mThought: I should look at the schema of the athlete_table to see if I can find the age of the athletes.
Action: schema_sql_db
Action Input: athlete_table[0m
Observation: [33;1m[1;3mCREATE TABLE `default`.`athlete_table` (
  `ID` INT,
  `Name` STRING,
  `Sex` STRING,
  `Age` STRING,
  `Height` STRING,
  `Weight` STRING,
  `Team` STRING,
  `NOC` STRING,
  `Games` STRING,
  `Year` STRING,
  `Season` STRING,
  `City` STRING,
  `Sport` STRING,
  `Event` STRING,
  `Medal` STRING)
;

/*
3 rows from athlete_table table:
ID	Name	Sex	Age	Height	Weight	Team	NOC	Games	Year	Season	City	Sport	Event	Medal
75056	Marco Marin	M	25	180	71	Italy	ITA	1988 Summer	1988	Summer	Seoul	Fencing	Fencing Men's Sabre, Individual	NA
1	A Dijiang	M	24	180	80	China	CHN	1992 Summer	1992	Summer	Barcelona	Basketball	Basketball Men's Basketball	NA


In [16]:
from langchain.chains import create_sql_query_chain

query_chain=create_sql_query_chain(llm,spark_sql)
response=query_chain.invoke({"question":"what is total atheletes"})
print(response) 

AttributeError: 'SparkSQL' object has no attribute 'dialect'