In [0]:
!pip install langchain langchain-groq langchain-community pyspark

Collecting langchain
  Downloading langchain-1.2.5-py3-none-any.whl.metadata (4.9 kB)
Collecting langchain-groq
  Downloading langchain_groq-1.1.1-py3-none-any.whl.metadata (2.4 kB)
Collecting langchain-community
  Downloading langchain_community-0.4.1-py3-none-any.whl.metadata (3.0 kB)
Collecting langchain-core<2.0.0,>=1.2.7 (from langchain)
  Downloading langchain_core-1.2.7-py3-none-any.whl.metadata (3.7 kB)
Collecting langgraph<1.1.0,>=1.0.2 (from langchain)
  Downloading langgraph-1.0.6-py3-none-any.whl.metadata (7.4 kB)
Collecting groq<1.0.0,>=0.30.0 (from langchain-groq)
  Downloading groq-0.37.1-py3-none-any.whl.metadata (16 kB)
Collecting langchain-classic<2.0.0,>=1.0.0 (from langchain-community)
  Downloading langchain_classic-1.0.1-py3-none-any.whl.metadata (4.2 kB)
Collecting SQLAlchemy<3.0.0,>=1.4.0 (from langchain-community)
  Downloading sqlalchemy-2.0.45-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata (9.5 kB)
Collecting request

In [0]:
import os
from langchain_groq import ChatGroq
from langchain_community.utilities import SparkSQL
from langchain_classic.chains import create_sql_query_chain
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool

# 1. Set up the API key
# Never paste the key into the notebook source: it would be saved with the file.
# An already-exported GROQ_API_KEY is kept as is; otherwise it is requested
# interactively without echoing it to the cell output.
import getpass

if not os.environ.get("GROQ_API_KEY"):
    os.environ["GROQ_API_KEY"] = getpass.getpass("Enter your Groq API key: ")

# Initialize the LLM
# temperature=0 asks the model for the least random completion available.
llm = ChatGroq(
    model="llama-3.3-70b-versatile",
    temperature=0,
)

In [0]:
# Creates a "virtual" database representation of the Gold tables.
# Only the tables listed in `tables` are exposed to the LLM.
catalog = "spotify_cata"
schema = "gold"
tables = ["fact_streams", "dim_artists", "dim_tracks"]

# LangChain Spark SQL wrapper.
# `include_tables` is what actually applies the allow-list above; without it the
# wrapper is built over every table found in the schema.
spark_sql = SparkSQL(schema=schema, catalog=catalog, include_tables=tables)
# NOTE(review): the SQL chain is expected to read `dialect` from the database
# object to fill the prompt's {dialect} placeholder; it is set by hand here,
# presumably because SparkSQL does not provide one - confirm against the
# installed langchain-community version.
spark_sql.dialect = 'spark'

In [0]:
from langchain_classic.chains import create_sql_query_chain
from langchain_core.prompts import PromptTemplate

# Custom prompt that instructs the LLM to return only SQL.
# create_sql_query_chain documents `input`, `top_k` and `table_info` as required
# prompt variables (`dialect` is optional), so the template uses all of them.
# `top_k` is filled from the chain's `k` argument (default 5).
custom_prompt = PromptTemplate.from_template(
    """Given an input question, create a syntactically correct {dialect} query to run.
Unless the question asks for a specific number of rows, limit the query to at most {top_k} results.

Only use the following tables:
{table_info}

Question: {input}

Return ONLY the SQL query without any explanation, markdown formatting, or additional text.
SQL Query:"""
)

# 1. Chain that generates the SQL query with the custom prompt
write_query = create_sql_query_chain(llm, spark_sql, prompt=custom_prompt)

# 2. Function that executes the SQL query using Spark
def execute_query(query):
    """Run LLM-generated SQL on Spark and return the result as plain text.

    Args:
        query: SQL text produced by the query-writing chain. Surrounding
            whitespace and a Markdown code fence (``` or ```sql) are removed
            before execution, because models sometimes add a fence even when
            told not to.

    Returns:
        The result set rendered with ``DataFrame.to_string()``.

    Raises:
        ValueError: If nothing is left after clean-up, or if the statement
            does not start with a read-only keyword.
    """
    import re  # local import so this cell does not depend on another cell's imports

    sql = query.strip()

    # Unwrap ```sql ... ``` / ``` ... ``` if the model fenced its answer.
    fenced = re.search(r"```(?:[\w+-]*[ \t]*\n)?(.*?)```", sql, flags=re.DOTALL)
    if fenced:
        sql = fenced.group(1).strip()
    if not sql:
        raise ValueError("The LLM returned an empty SQL query.")

    # The SQL text is model output, i.e. untrusted. Refuse anything that does not
    # begin with a read-only keyword. This is a first line of defence only (Spark
    # also accepts e.g. `WITH ... INSERT`); the session should additionally run
    # with read-only privileges on the exposed tables.
    read_only_keywords = {"SELECT", "WITH", "SHOW", "DESCRIBE", "DESC", "EXPLAIN"}
    leading = re.match(r"[\s(]*([A-Za-z]+)", sql)
    keyword = leading.group(1).upper() if leading else ""
    if keyword not in read_only_keywords:
        raise ValueError(f"Refusing to run a non read-only statement: {sql[:80]!r}")

    result = spark.sql(sql)
    return result.toPandas().to_string()

# 3. Combine them: Write Query -> Execute Query.
# The question dict goes into `write_query`; the SQL string it produces is piped
# into `execute_query`, whose return value (the result table rendered as text)
# is the output of the chain. There is no separate LLM "answer" step.
chain = write_query | execute_query

In [0]:
# Question 1: ask for a ranking and print the returned result table.
question_1 = "Who are the top 3 artists by total streams?"
response = chain.invoke({"question": question_1})
print(response)

        artist_name  total_streams
0    Kirsten Harris             13
1      Kendra Bates             11
2  Jennifer Montoya             10


In [0]:
# Question 2 (Complex Join): ask for an aggregate and print the returned result table.
question_2 = "What is the average duration of songs in the Pop genre?"
response = chain.invoke({"question": question_2})
print(response)

   average_duration
0        212.715789
