<a href="https://colab.research.google.com/github/kavyajeetbora/foursquare_ai/blob/master/notebooks/03_NL2SQL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

References for this tutorial:

1. [NLP to SQL using LangChain](https://blog.futuresmart.ai/mastering-natural-language-to-sql-with-langchain-nl2sql)
2. [NLP to SQL using LangChain - Notebook](https://github.com/PradipNichite/Youtube-Tutorials/blob/main/Langchain_NL2SQL_2024.ipynb)


## Setup Environment

In [None]:
!pip install -q langchain langchain-community langchain-openai langgraph langsmith duckdb duckdb-engine anytree pandas python-dotenv

In [None]:
import pandas as pd
from anytree import Node, RenderTree

def build_anytree_from_csv(csv_path_or_df):
    # Accept file path or DataFrame
    if isinstance(csv_path_or_df, str):
        df = pd.read_csv(csv_path_or_df)
    else:
        df = csv_path_or_df

    # Specify the order of columns representing the hierarchy
    level_cols = [
        "Category_Level_1",
        "Category_Level_2",
        "Category_Level_3",
        "Category_Level_4",
        "Category_Name"  # Category_Name is always the leaf
    ]

    root = Node("Root")
    nodes = {"Root": root}

    for _, row in df.iterrows():
        parent = root
        path = []
        for col in level_cols:
            category = row.get(col, None)
            if pd.notnull(category) and str(category).strip().lower() != "nan":
                category = str(category).strip()
                path.append(category)
                node_key = "/".join(path)
                if node_key not in nodes:
                    nodes[node_key] = Node(category, parent=parent)
                parent = nodes[node_key]
    return root

# Example usage:
# tree_root = build_anytree_from_csv('your_categories.csv')
# for pre, _, node in RenderTree(tree_root):
#     print(f"{pre}{node.name}")
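The function above keys each node by its full path (`"/".join(path)`), so a category name that appears under two different parents produces two separate nodes. The dependency-free sketch below illustrates the same idea with nested dictionaries. The sample rows and the `build_nested` and `render` helpers are made up for illustration and are not part of the notebook or the Foursquare data.

```python
def build_nested(rows):
    """Build a nested dict tree from rows of category levels (root level first)."""
    tree = {}
    for levels in rows:
        node = tree
        for name in levels:
            if name is None or not str(name).strip():
                continue  # skip empty hierarchy levels, like the NaN check above
            node = node.setdefault(str(name).strip(), {})
    return tree


def render(tree, indent=0):
    """Return the tree as a list of indented lines, one category per line."""
    lines = []
    for name, children in tree.items():
        lines.append("  " * indent + name)
        lines.extend(render(children, indent + 1))
    return lines


# Made-up rows: "Bar" appears under two different parents.
sample_rows = [
    ("Dining and Drinking", "Bar", None),
    ("Dining and Drinking", "Restaurant", "Indian Restaurant"),
    ("Travel and Transportation", "Hotel", "Bar"),
]
tree = build_nested(sample_rows)
print("\n".join(render(tree)))
```

Because each entry is addressed by its position in the hierarchy, the two "Bar" entries never collide, which is the same guarantee the path-based `node_key` gives in `build_anytree_from_csv`.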

## Setting Up the Database

In [None]:
import json
from dotenv import load_dotenv
import os

import duckdb
from langchain_community.utilities import SQLDatabase
from sqlalchemy import create_engine

from langchain_openai import ChatOpenAI  # Or from langchain_groq import ChatGroq
from langchain.chains import create_sql_query_chain
from langchain_core.prompts import PromptTemplate
from langsmith import traceable

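# Load environment variables (e.g. OPENAI_API_KEY, LANGSMITH_API_KEY) from the .env file
load_dotenv(dotenv_path='.env', override=True)

# LangSmith tracing configuration
os.environ["LANGSMITH_TRACING"] = "true"
os.environ["LANGSMITH_ENDPOINT"] = "https://api.smith.langchain.com"
os.environ["LANGSMITH_PROJECT"] = "foursqare_poi_data"
# load_dotenv already exports LANGSMITH_API_KEY when it is present in .env.
# Fail early with a clear message if it is missing (assigning None to os.environ raises TypeError).
if not os.getenv("LANGSMITH_API_KEY"):
    raise RuntimeError("LANGSMITH_API_KEY is not set; add it to .env")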

In [None]:
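def execute_sql(sql_query: str, con):
    """Run a query on the given DuckDB connection; return all rows, or an error string."""
    try:
        # Fill in the {s3_path} placeholder if the query still contains it.
        # str.replace is used instead of str.format because SQL can contain
        # literal braces (e.g. DuckDB struct literals), which str.format rejects.
        sql_query = sql_query.replace("{s3_path}", S3_PATH)
        result = con.execute(sql_query).fetchall()
        return result
    except Exception as e:
        return f"Error executing SQL: {str(e)}"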

def get_duckdb_connection():
    con = duckdb.connect(database=':memory:')  # In-memory for simplicity; use a file path for persistence if needed
    con.execute("INSTALL httpfs;")
    con.execute("LOAD httpfs;")
    # Optional: Set S3 region if needed (public bucket, so usually not required)
    # con.execute("SET s3_region='us-east-1';")
    return con

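def get_db_schema(s3_path, columns):
    """Describe the selected columns with sample values (uses the global duckdb_conn)."""
    data_schema = "Columns:\n"

    # Sample five Indian rows in which every selected column is populated
    column_list = ", ".join(columns)
    conditions = ["country = 'IN'"] + [f"{column} IS NOT NULL" for column in columns]
    sql_query = (
        f"SELECT {column_list} FROM read_parquet('{s3_path}') "
        f"WHERE {' AND '.join(conditions)} LIMIT 5"
    )
    sample_result = execute_sql(sql_query, duckdb_conn)
    schema_details = execute_sql(f"DESCRIBE {sql_query}", duckdb_conn)

    # execute_sql returns an error string on failure; surface it instead of indexing into it
    for result in (sample_result, schema_details):
        if isinstance(result, str):
            raise RuntimeError(result)

    for i, column in enumerate(columns):
        data_type = schema_details[i][1]
        sample_values = ",".join(str(r[i]) for r in sample_result)
        data_schema += f"{i+1}. Name: {column} | Data Type: {data_type} | Sample values: {sample_values}\n"
        data_schema += "\n"
    return data_schema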

## Setup Database Connection

In [None]:
# Get the DuckDB connection
duckdb_conn = get_duckdb_connection()

# Create a SQLAlchemy engine using the duckdb-engine dialect.
# Note: this opens its own in-memory database, separate from duckdb_conn.
# It is needed only because create_sql_query_chain expects a SQLDatabase object.
engine = create_engine('duckdb:///:memory:')


# Wrap in LangChain's SQLDatabase
db = SQLDatabase(engine=engine, sample_rows_in_table_info=2)  # sample rows per table included in table_info for the LLM

## Get the Table Info

In [None]:
# Define the S3 path (make this configurable)
S3_PATH = "s3://fsq-os-places-us-east-1/release/dt=2025-09-09/places/parquet/places-*.zstd.parquet"
columns = ['fsq_place_id', 'name', 'latitude', 'longitude', 'postcode', 'fsq_category_labels']
data_schema = get_db_schema(S3_PATH, columns)

print(data_schema)

## Choose and Initialize the LLM

Compare OpenAI model pricing here (the figures below are the ones noted when this notebook was written):
https://platform.openai.com/docs/pricing

- ✅ Cheapest overall → gpt-5-nano ($0.45 per 2M tokens processed, i.e. 1M input plus 1M output tokens)

- ⚖️ Second cheapest → gpt-4.1-nano ($0.50 on the same basis)

- Use the OpenAI tokenizer to count the tokens in a given input: https://platform.openai.com/tokenizer
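As a rough illustration of the arithmetic behind these figures, the sketch below scales a price quoted per 2M processed tokens to an arbitrary workload. The rates are copied from the bullets above and may be out of date; the workload size and the `estimate_cost_usd` helper are hypothetical. It treats input and output tokens alike, which is a simplification.

```python
def estimate_cost_usd(total_tokens: int, usd_per_2m_tokens: float) -> float:
    """Scale a price quoted per 2M processed tokens to an arbitrary token count."""
    return total_tokens / 2_000_000 * usd_per_2m_tokens


# Hypothetical workload: 1,000 questions at about 1,500 tokens each (prompt plus answer)
tokens = 1_000 * 1_500
print(f"gpt-5-nano:   ${estimate_cost_usd(tokens, 0.45):.4f}")
print(f"gpt-4.1-nano: ${estimate_cost_usd(tokens, 0.50):.4f}")
```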

In [None]:
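# Initialize the LLM (ChatOpenAI reads the key from the OPENAI_API_KEY environment variable)
# temperature=0 keeps SQL generation as deterministic as possible
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)  # Or another chat model, e.g. ChatGroq from langchain_groq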

# Pre-prompt for NL-to-SQL (customize this!)
# create_sql_query_chain requires {input}, {top_k} and {table_info} in the prompt.
# It fills {table_info} from `db`, which has no tables registered here, so the schema
# sampled from S3 (data_schema) and the S3 path are supplied as partial variables.
SQL_PROMPT_TEMPLATE = """
You are a SQL expert for querying Foursquare POI data using DuckDB. The data is in Parquet files at {s3_path}.
Key columns: name (string), country (string, e.g., 'IN' for India), latitude and longitude (numeric), fsq_category_labels (array of strings), etc.

Based on the schema below and the user's question, generate a valid DuckDB SQL query.
Schema:
{data_schema}
{table_info}

Constraints:
Always restrict the outlets (POIs) to India: use country = 'IN' in the WHERE clause of every generated SQL query.
If the user asks about any other country, reply that the request is beyond scope and do not generate a SQL query.

User question: {input}
Limit the results to {top_k}.

Generate a valid DuckDB SQL query to answer this. Always use read_parquet('{s3_path}') as the data source.
Do not add explanations; output only the SQL query.
Examples:
- Question: How many Starbucks outlets are there in India?
  SQL: SELECT COUNT(*) AS count FROM read_parquet('{s3_path}') WHERE country = 'IN' AND (LOWER(name) LIKE '%starbucks%' OR LOWER(array_to_string(fsq_category_labels, ',')) LIKE '%starbucks%');
"""

# Create the prompt template with the required input variables and the partial variables
sql_prompt = PromptTemplate(
    input_variables=["input", "top_k", "table_info"],
    template=SQL_PROMPT_TEMPLATE,
    partial_variables={"s3_path": S3_PATH, "data_schema": data_schema},
)

# Now create the chain using the updated prompt
sql_chain = create_sql_query_chain(llm=llm, db=db, prompt=sql_prompt)

Track the LLM traces in this dashboard: https://smith.langchain.com/

In [None]:
question = "How many biryani outlets are there in India?"
sql_input = {"input": question, "question": question, "s3_path": S3_PATH, "top_k": 1, "table_info": data_schema} # Include both 'input' and 'question' keys
sql_query = sql_chain.invoke(sql_input).strip()  # Strip any extra whitespace
print(sql_query)

In [None]:
RESPONSE_PROMPT_TEMPLATE = """
Query results: {results}

Based on the results, provide a friendly, easy-to-understand answer to the user's question: {question}
Keep it concise and natural.
"""

response_prompt = PromptTemplate.from_template(RESPONSE_PROMPT_TEMPLATE)
response_chain = response_prompt | llm  # Simple chain: prompt -> LLM

In [None]:
question = "What is the count of retail stores in India?"
sql_input = {"input": question, "question": question, "s3_path": S3_PATH, "top_k": 10} # Include both 'input' and 'question' keys

# Step 1: Generate the SQL query (wrapped with @traceable so the call is logged in LangSmith)
@traceable
def get_sql_query(sql_input):
    return sql_chain.invoke(sql_input).strip()  # Strip any extra whitespace

sql_query = get_sql_query(sql_input)

# Remove markdown code block formatting if present
if sql_query.startswith("```sql") and sql_query.endswith("```"):
    sql_query = sql_query[6:-3].strip() # Remove "```sql" and "```" and strip whitespace

print(sql_query)

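# Step 2: Run the generated SQL on DuckDB
result = execute_sql(sql_query, duckdb_conn)

# Step 3: Generate a friendly response
# default=str keeps json.dumps from failing on values such as dates or decimals
response_input = {"results": json.dumps(result, default=str), "question": question}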

@traceable
def generate_friendly_response(response_input):
    return response_chain.invoke(response_input).content

friendly_response = generate_friendly_response(response_input)

print(friendly_response)
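The fence-stripping step in the cell above only handles replies that start with a `sql` fence. A slightly more general cleanup could look like the sketch below. `extract_sql` is a hypothetical helper, not part of LangChain, and the `SQLQuery:` handling assumes the model may echo the cue that `create_sql_query_chain` appends to the question.

```python
import re

FENCE = "`" * 3  # Markdown code-fence marker, built here to avoid writing it literally

# A whole reply wrapped in a fence, with an optional language tag on the first line
_FENCE_RE = re.compile(
    r"^" + re.escape(FENCE) + r"[a-zA-Z]*[ \t]*\n(.*?)\n?" + re.escape(FENCE) + r"$",
    re.DOTALL,
)


def extract_sql(llm_output: str) -> str:
    """Return the SQL text from an LLM reply, without fences or a 'SQLQuery:' prefix."""
    text = llm_output.strip()
    match = _FENCE_RE.match(text)
    if match:
        text = match.group(1).strip()
    # Assumption: some models repeat the "SQLQuery:" cue before the query
    if text.lower().startswith("sqlquery:"):
        text = text[len("sqlquery:"):].strip()
    return text


reply = FENCE + "sql\nSELECT COUNT(*) FROM t;\n" + FENCE
print(extract_sql(reply))
```

Replies that are not wrapped in a fence pass through unchanged apart from whitespace trimming, so the helper can be applied to every model output.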

In [None]:
# Sanity check: recursively unnest a struct literal (the braces are DuckDB struct syntax)
query = "SELECT unnest({'a': [1, 2, 3], 'b': 88}, recursive := true);"
execute_sql(query, duckdb_conn)

## Business Use Cases

1. How many Pizza Hut locations are there in India?
2. What is the count of convenience stores in India?
3. List the top 10 cities in India with the most fast food restaurants.
4. How many grocery stores are there in New Delhi?
5. Find beverage stores near latitude 19.0760, longitude 72.8777 (Mumbai).
6. What is the number of cafes in India that were created after 2024-01-01?
7. Count of McDonald's locations in India.
8. Top 5 states in India with the highest number of restaurants.
9. How many supermarkets have been closed in India this year?
10. List fast food places in Bengaluru with a website.
11. What is the average latitude of convenience stores in India?
12. Count of POIs categorized as 'Food and Beverage Shop' in India.
13. Find Pizza Hut outlets in Mumbai, India.
14. How many beverage-related POIs are there in India?
15. Top cities in India with the most health food stores (potential competitors to sugary beverages).
16. Count of restaurants with Instagram handles in India.
17. List closed fast food locations in India since 2023.
18. How many grocery stores are in postal code 400001 (Mumbai)?
19. Find nearby cafes to latitude 28.7041, longitude 77.1025 (Delhi) within a small radius.
20. What is the total number of retail food stores in India that have a telephone number?
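Questions 5 and 19 ask for places "near" a coordinate, which has to be turned into a numeric filter on the `latitude` and `longitude` columns. One possible approach, sketched below, is an approximate bounding box for the SQL filter plus a haversine distance for exact checks. The helper names and the 2 km radius are assumptions for illustration; the box is an approximation that works for small radii away from the poles.

```python
import math

EARTH_RADIUS_KM = 6371.0  # mean Earth radius


def haversine_km(lat1, lon1, lat2, lon2):
    """Great-circle distance between two points, in kilometres."""
    phi1, phi2 = math.radians(lat1), math.radians(lat2)
    dphi = phi2 - phi1
    dlambda = math.radians(lon2 - lon1)
    a = math.sin(dphi / 2) ** 2 + math.cos(phi1) * math.cos(phi2) * math.sin(dlambda / 2) ** 2
    return 2 * EARTH_RADIUS_KM * math.asin(math.sqrt(a))


def bounding_box(lat, lon, radius_km):
    """Return (min_lat, max_lat, min_lon, max_lon) roughly covering a circle of radius_km."""
    dlat = math.degrees(radius_km / EARTH_RADIUS_KM)
    dlon = math.degrees(radius_km / (EARTH_RADIUS_KM * math.cos(math.radians(lat))))
    return lat - dlat, lat + dlat, lon - dlon, lon + dlon


def nearby_where_clause(lat, lon, radius_km):
    """Build a WHERE fragment on the latitude/longitude columns, restricted to India."""
    min_lat, max_lat, min_lon, max_lon = bounding_box(lat, lon, radius_km)
    return (
        f"country = 'IN' AND latitude BETWEEN {min_lat:.5f} AND {max_lat:.5f} "
        f"AND longitude BETWEEN {min_lon:.5f} AND {max_lon:.5f}"
    )


# Hypothetical 2 km search around the Mumbai coordinate from question 5
print(nearby_where_clause(19.0760, 72.8777, 2.0))
# Distance between the Mumbai and Delhi coordinates from questions 5 and 19, in km
print(round(haversine_km(19.0760, 72.8777, 28.7041, 77.1025)))
```

The bounding box keeps the filter simple enough for the Parquet scan, and `haversine_km` can then be applied to the returned rows to drop points in the corners of the box.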