<a href="https://colab.research.google.com/github/kavyajeetbora/foursquare_ai/blob/master/notebooks/07_duckdb_ai_bot.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Setup the Python Environment

In [1]:
!pip install --quiet duckdb jupysql duckdb-engine
## Langchain Framework
!pip install --quiet langchain langchain-community langchain-openai langgraph "langchain[openai]"

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/95.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m95.1/95.1 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m49.7/49.7 kB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m192.8/192.8 kB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m137.3/137.3 kB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.7/43.7 kB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.5/2.5 MB[0m [31m37.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m75.6/75.6 kB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m 

In [49]:
import duckdb
import os
import pandas as pd
from dotenv import load_dotenv
from google.colab import files

from langchain_openai import ChatOpenAI
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate

from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser

from pydantic import BaseModel, Field, field_validator
import json
from langchain.prompts import ChatPromptTemplate


# Check if .env file exists, if not, prompt for upload
if not os.path.exists(".env"):
    uploaded = files.upload()
    if ".env" not in uploaded:
        print("Please upload a .env file to proceed.")
    else:
        with open(".env", "wb") as f:
            f.write(uploaded[".env"])
        load_dotenv(dotenv_path = ".env", override=True)
else:
    load_dotenv(dotenv_path = ".env", override=True)

In [3]:
os.environ["LANGSMITH_TRACING"] = "true"
os.environ["LANGSMITH_ENDPOINT"] = "https://api.smith.langchain.com"
os.environ["LANGSMITH_PROJECT"] = "sql_agent_project"
os.environ["LANGSMITH_API_KEY"] = os.getenv("LANGSMITH_API_KEY") # Ensure you have a way to get this

## Data Engineering Part

- Here we will prepare the data for our analysis
- Clean Data
- Load it to local storage for faster loading and querying

In [4]:
# Initialize DuckDB connection
con = duckdb.connect()

# Load required extensions
con.execute("INSTALL httpfs; LOAD httpfs; INSTALL spatial; LOAD spatial;")

s3_places_path = 's3://fsq-os-places-us-east-1/release/dt=2025-09-09/places/parquet/places-*.zstd.parquet'
s3_categories_path = 's3://fsq-os-places-us-east-1/release/dt=2025-09-09/categories/parquet/categories.zstd.parquet'

# Execute the SELECT query and create a view
con.execute(f"""
CREATE OR REPLACE VIEW places_with_categories AS
WITH places AS (
    SELECT
        DISTINCT UNNEST(P.fsq_category_ids) as fsq_category_id,
        name,
        postcode,
        address,
        region,
        ST_Point(longitude, latitude) AS geom
    FROM read_parquet('{s3_places_path}') AS P
    WHERE latitude IS NOT NULL AND longitude IS NOT NULL AND country='IN'
),
places_with_categories AS (
    SELECT
        P.name AS name,
        C.level1_category_name AS category_level_1,
        C.level2_category_name AS category_level_2,
        postcode,
        address,
        region,
        P.geom
    FROM places AS P
    JOIN read_parquet('{s3_categories_path}') AS C
    ON P.fsq_category_id = C.category_id
)
SELECT
    name,
    category_level_1,
    category_level_2,
    address,
    region,
    postcode,
    geom
FROM places_with_categories;
""")

# Export the view to GeoParquet
con.execute("COPY (SELECT * FROM places_with_categories) TO 'output.geoparquet' WITH (FORMAT PARQUET, CODEC ZSTD);")

## Check the total count of the database:
# con.execute("SELECT COUNT(*) FROM places_with_categories;")
# result = con.fetchone()[0]
# print(result)
## Around 1358392 points are there

# Close the connection
con.close()

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

# AI ChatBot

-

## Setup Prompt Templates

- System prompts

In [31]:
system_message = """
Given an input question, create a syntactically correct DuckDB's SQL dialect query to
run to help find the answer. Unless the user specifies in his question a
specific number of examples they wish to obtain, always limit your query to
at most {top_k} results. You can order the results by a relevant column to
return the most interesting examples in the database.

Never query for all the columns from a specific table, only ask for a the
few relevant columns given the question.

Pay attention to use only the column names that you can see in the schema
description. Be careful to not query for columns that do not exist. Also,
pay attention to which column is in which table.

Only use the following tables:
{table_info}
"""

user_prompt = "Question: {input}"

query_prompt_template = ChatPromptTemplate(
    [("system", system_message), ("user", user_prompt)]
)

for message in query_prompt_template.messages:
    message.pretty_print()



Given an input question, create a syntactically correct DuckDB's SQL dialect query to
run to help find the answer. Unless the user specifies in his question a
specific number of examples they wish to obtain, always limit your query to
at most [33;1m[1;3m{top_k}[0m results. You can order the results by a relevant column to
return the most interesting examples in the database.

Never query for all the columns from a specific table, only ask for a the
few relevant columns given the question.

Pay attention to use only the column names that you can see in the schema
description. Be careful to not query for columns that do not exist. Also,
pay attention to which column is in which table.

Only use the following tables:
[33;1m[1;3m{table_info}[0m


Question: [33;1m[1;3m{input}[0m


## Helper functions

- Setting up connection with duckdb
- Then generating shema details for all the tables in the database
- Define the input and the output structure format for all the parameters are going to be passed on to the llm

In [68]:
# Pydantic Models with @field_validator
class State(BaseModel):
    question: str = Field(
        default="",
        description="The natural language question or input provided by the user, intended to be translated into a DuckDB SQL query. Example: 'Show me all customers with orders above $100.'"
    )
    query: str = Field(
        default="",
        description="The generated DuckDB SQL query derived from the user's natural language question. Must conform to DuckDB's SQL syntax and conventions. Example: 'SELECT * FROM customers WHERE order_total > 100;'"
    )
    result: str = Field(
        default="",
        description="The output or result set returned by executing the DuckDB SQL query against the database, formatted as a JSON string. Example: '[{\"id\": 1, \"name\": \"Alice\", \"order_total\": 150}]'"
    )
    answer: str = Field(
        default="",
        description="The final natural language response generated for the user, summarizing or explaining the query results in a conversational manner. Example: 'Here are the customers with orders above $100: Alice with $150.'"
    )

    @field_validator("result", mode="after")
    @classmethod
    def validate_result(cls, result: str) -> str:
        try:
            json.loads(result)
        except json.JSONDecodeError:
            raise ValueError("Result must be a valid JSON string")
        return result

class QueryOutput(BaseModel):
    query: str = Field(..., description="Syntactically valid DuckDB SQL query")

    @field_validator("query", mode="after")
    @classmethod
    def validate_query(cls, query: str) -> str:
        if not query.strip().endswith(";"):
            raise ValueError("DuckDB query must end with a semicolon")
        return query

# Prompt Setup (unchanged)
system_message = """
Given an input question, create a syntactically correct DuckDB's SQL dialect query to
run to help find the answer. Unless the user specifies in his question a
specific number of examples they wish to obtain, always limit your query to
at most {top_k} results. You can order the results by a relevant column to
return the most interesting examples in the database.

Never query for all the columns from a specific table, only ask for the
few relevant columns given the question.

Pay attention to use only the column names that you can see in the schema
description. Be careful to not query for columns that do not exist. Also,
pay attention to which column is in which table.

Only use the following tables:
{table_info}
"""
user_prompt = "Question: {input}"
query_prompt_template = ChatPromptTemplate.from_messages([("system", system_message), ("user", user_prompt)])

class FourSquareChatBot:
    """NLP-to-SQL chatbot for querying DuckDB databases, optimized for Parquet files and FourSquare data."""

    def __init__(self, data_path: str, columns: list[str], llm, database: str = ":memory:"):
        """
        Initialize the chatbot with a DuckDB connection and schema.

        Args:
            data_path (str): Path to the Parquet file (e.g., S3 URL or local path).
            columns (list[str]): List of column names to include in the schema.
            llm: Language model instance for generating SQL queries and answers.
            database (str): DuckDB database path (default: ':memory:' for in-memory).
        """
        self.data_path = data_path
        self.columns = columns
        self.llm = llm
        self.conn = self._get_duckdb_connection(database)
        self.table_info = self._get_db_schema(limit=5)

    def _get_duckdb_connection(self, database: str) -> duckdb.DuckDBPyConnection:
        """Create and configure a DuckDB connection."""
        con = duckdb.connect(database=database)
        con.execute("INSTALL httpfs;")
        con.execute("LOAD httpfs;")
        con.execute("INSTALL spatial;")
        con.execute("LOAD spatial;")
        # Optional: Set S3 region if needed
        # con.execute("SET s3_region='us-east-1';")
        return con

    def _get_db_schema(self, limit=5):

        """Generate schema information from the Parquet file with sample values."""

        data_schema = f"""Columns:\n"""

        sql_query = f"SELECT {",".join(self.columns)} FROM read_parquet('{self.data_path}') WHERE 1=1"

        for column in self.columns:
            sql_query += f" AND {column} IS NOT NULL"

        sql_query += f" LIMIT {limit};"
        sample_result = self._execute_sql(sql_query)

        schema_details = self._execute_sql(f'DESCRIBE {sql_query}')

        for i, column in enumerate(self.columns):
            data_type = schema_details[i][1]
            sample_values = ",".join([str(r[i]) for r in sample_result])
            data_schema += f"{i+1}. Name: {column} | Data Type: {data_type} | Sample values: {sample_values}"
            data_schema += "\n"
        return data_schema

    def _execute_sql(self, sql_query: str) -> list:
        """Execute a SQL query and return results."""
        try:
            result = self.conn.execute(sql_query).fetchall()
            return result
        except Exception as e:
            return f"Error executing SQL: {str(e)}"

    def generate_sql_query(self, state: State) -> QueryOutput:
        """Generate a DuckDB SQL query from the user's question."""
        prompt = query_prompt_template.invoke(
            {
                "top_k": 10,
                "table_info": self.table_info,
                "input": state.question
            }
        )
        response = self.llm.invoke(prompt)
        return QueryOutput(query=response.content)

    def generate_answer(self, state: State) -> dict:
        """Generate a conversational answer using query results."""
        # Execute the query and store JSON result
        result = self._execute_sql(state.query)
        if isinstance(result, str) and result.startswith("Error"):
            state.result = json.dumps({"error": result})
            state.answer = f"Sorry, I couldn't process your query due to an error: {result}"
        else:
            columns = [desc[0] for desc in self.conn.description]
            state.result = json.dumps([dict(zip(columns, row)) for row in result])

            # Generate conversational answer
            prompt = (
                "Given the following user question, corresponding SQL query, "
                "and SQL result, answer the user question in a conversational manner.\n\n"
                f"Question: {state.question}\n"
                f"SQL Query: {state.query}\n"
                f"SQL Result: {state.result}"
            )
            response = self.llm.invoke(prompt)
            state.answer = response.content

        return {"state": state}

    def process_question(self, question: str) -> dict:
        """Process a user question end-to-end and return the updated state."""
        state = State(question=question)
        query_output = self.generate_sql_query(state)
        state.query = query_output.query
        result = self.generate_answer(state)
        return result

    def __del__(self):
        """Clean up by closing the DuckDB connection."""
        if hasattr(self, "conn"):
            self.conn.close()

In [69]:
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)

fsq_chat_bot = FourSquareChatBot(
    data_path = "/content/output.geoparquet",
    columns = ['name', 'category_level_1', 'category_level_2', 'address', 'region', 'postcode'],
    llm = llm
)

VARCHAR
VARCHAR
VARCHAR
VARCHAR
VARCHAR
VARCHAR


In [70]:
print(fsq_chat_bot.table_info)

Columns:
1. Name: name | Data Type: VARCHAR | Sample values: Indane - Boham Gramin Vitrak,Gulf Carstop - Gogoi Automobile,Indane - Wangcha Gas Agency,Axis Bank,Axis Bank ATM
2. Name: category_level_1 | Data Type: VARCHAR | Sample values: Travel and Transportation,Retail,Travel and Transportation,Business and Professional Services,Business and Professional Services
3. Name: category_level_2 | Data Type: VARCHAR | Sample values: Fuel Station,Automotive Retail,Fuel Station,Financial Service,Financial Service
4. Name: address | Data Type: VARCHAR | Sample values: Ground Floor,Ground Floor,Bagh Moria Gaon,Ground Floor,1st Floor, NH 52,1st Floor,NH 52
5. Name: region | Data Type: VARCHAR | Sample values: Arunachal Pradesh,Assam,Arunachal Pradesh,Arunachal Pradesh,Arunachal Pradesh
6. Name: postcode | Data Type: VARCHAR | Sample values: 792130,785692,792131,791110,791110



In [53]:
class FourSquareChatBot:


def get_duckdb_connection():
    con = duckdb.connect(database=':memory:')  # In-memory for simplicity; use a file path for persistence if needed
    con.execute("INSTALL httpfs;")
    con.execute("LOAD httpfs;")
    con.execute("INSTALL spatial;")
    con.execute("LOAD spatial;")
    # Optional: Set S3 region if needed (public bucket, so usually not required)
    # con.execute("SET s3_region='us-east-1';")
    return con

def get_db_schema(DATA_PATH, columns, duckdb_conn, limit=5):

    data_schema = f"""Columns:\n"""

    sql_query = f"SELECT {",".join(columns)} FROM read_parquet('{DATA_PATH}') WHERE 1=1"

    for column in columns:
        sql_query += f" AND {column} IS NOT NULL"

    sql_query += f" LIMIT {limit};"
    sample_result = execute_sql(sql_query, duckdb_conn)

    schema_details = execute_sql(f'DESCRIBE {sql_query}', duckdb_conn)

    for i, column in enumerate(columns):
        data_type = schema_details[i][1]
        sample_values = ",".join([str(r[i]) for r in sample_result])
        data_schema += f"{i+1}. Name: {column} | Data Type: {data_type} | Sample values: {sample_values}"
        data_schema += "\n"
    return data_schema


def execute_sql(sql_query: str, con):
    try:
        # Replace {s3_path} if not already in the query
        result = con.execute(sql_query).fetchall()
        return result
    except Exception as e:
        return f"Error executing SQL: {str(e)}"


def generate_sql_query(state: State, table_info: str) -> QueryOutput:
    "Generate SQL Query to fetch information"
    prompt = query_prompt_template.invoke(
        {
            "top_k": 10,
            "table_info": table_info,
            "input": state.question
        }
    )
    response = llm.invoke(prompt)
    return response

# Generate answer (already correct)
def generate_answer(state: State):
    """Answer question using retrieved information as context."""
    prompt = (
        "Given the following user question, corresponding SQL query, "
        "and SQL result, answer the user question.\n\n"
        f"Question: {state.question}\n"
        f"SQL Query: {state.query}\n"
        f"SQL Result: {state.result}"
    )
    response = llm.invoke(prompt)
    return {"answer": response.content}

Table: data
Columns:
1. Name: name | Data Type: name | Sample values: Indane - Boham Gramin Vitrak,Gulf Carstop - Gogoi Automobile,Indane - Wangcha Gas Agency,Axis Bank,Axis Bank ATM
2. Name: category_level_1 | Data Type: category_level_1 | Sample values: Travel and Transportation,Retail,Travel and Transportation,Business and Professional Services,Business and Professional Services
3. Name: category_level_2 | Data Type: category_level_2 | Sample values: Fuel Station,Automotive Retail,Fuel Station,Financial Service,Financial Service
4. Name: address | Data Type: address | Sample values: Ground Floor,Ground Floor,Bagh Moria Gaon,Ground Floor,1st Floor, NH 52,1st Floor,NH 52
5. Name: region | Data Type: region | Sample values: Arunachal Pradesh,Assam,Arunachal Pradesh,Arunachal Pradesh,Arunachal Pradesh
6. Name: postcode | Data Type: postcode | Sample values: 792130,785692,792131,791110,791110



In [41]:
duckdb_conn = get_duckdb_connection()

DATA_PATH = "/content/output.geoparquet"
columns = ['name', 'category_level_1', 'category_level_2', 'address', 'region', 'postcode']

table_info = get_db_schema(DATA_PATH, columns, duckdb_conn, limit=10)
print(table_info)

Columns:
1. Name: name | Data Type: VARCHAR | Sample values: Indane - Boham Gramin Vitrak,Gulf Carstop - Gogoi Automobile,Indane - Wangcha Gas Agency,Axis Bank,Axis Bank ATM,HDFC ERGO Insurance Agent: Partha P Bhattacharjee,HDFC ERGO Insurance Agent: Mintu Roy,HDFC Bank ATM,HDFC Bank,Punjab National Bank
2. Name: category_level_1 | Data Type: VARCHAR | Sample values: Travel and Transportation,Retail,Travel and Transportation,Business and Professional Services,Business and Professional Services,Business and Professional Services,Business and Professional Services,Business and Professional Services,Business and Professional Services,Business and Professional Services
3. Name: category_level_2 | Data Type: VARCHAR | Sample values: Fuel Station,Automotive Retail,Fuel Station,Financial Service,Financial Service,Financial Service,Financial Service,Financial Service,Financial Service,Financial Service
4. Name: address | Data Type: VARCHAR | Sample values: Ground Floor,Ground Floor,Bagh Moria 

In [42]:
from langgraph.graph import START, StateGraph, END
from IPython.display import Image, display

graph_builder = StateGraph(State)
graph_builder.add_sequence([generate_sql_query])
graph_builder.add_edge(START, "generate_sql_query")
# graph_builder.add_edge("generate_answer", END)
graph = graph_builder.compile()


In [44]:
question = "Tell me the count for all the points that are related to manufacturing?"
for step in graph.stream(
    {
        "question": question,
        "query": "",
        "result": "",
        "answer": "",
        "table_info": table_info # Add table_info here
    },
    config={"configurable": {"thread_id": "1"}},
    stream_mode="updates"

):
    print(step)

print("-"*100,"\n")
print(step['generate_sql_query'])

TypeError: generate_sql_query() missing 1 required positional argument: 'table_info'

Columns:
1. Name: name | Data Type: VARCHAR | Sample values: Indane - Boham Gramin Vitrak,Gulf Carstop - Gogoi Automobile,Indane - Wangcha Gas Agency,Axis Bank,Axis Bank ATM,HDFC ERGO Insurance Agent: Partha P Bhattacharjee,HDFC ERGO Insurance Agent: Mintu Roy,HDFC Bank ATM,HDFC Bank,Punjab National Bank
2. Name: category_level_1 | Data Type: VARCHAR | Sample values: Travel and Transportation,Retail,Travel and Transportation,Business and Professional Services,Business and Professional Services,Business and Professional Services,Business and Professional Services,Business and Professional Services,Business and Professional Services,Business and Professional Services
3. Name: category_level_2 | Data Type: VARCHAR | Sample values: Fuel Station,Automotive Retail,Fuel Station,Financial Service,Financial Service,Financial Service,Financial Service,Financial Service,Financial Service,Financial Service
4. Name: address | Data Type: VARCHAR | Sample values: Ground Floor,Ground Floor,Bagh Moria 

## Initialize the chatbot

In [23]:


# Pre-prompt for NL-to-SQL (customize this!)
# Updated prompt to include required input variables for create_sql_query_chain
# and incorporate the S3_PATH as a partial variable
SQL_PROMPT_TEMPLATE = """
You are a SQL expert for querying Foursquare POI data using DuckDB. The data is in Parquet files at {s3_path}.

Based on the schema below and the user's question, generate a valid DuckDB SQL query.
Schema:
{table_info}

Contraints:
Always restrict the oulets aka POI withing INDIA.
If any other country is asked by the uers just say something like beyond scope and donot generate any sql prompt"

User question: {input}
Limit the results to {top_k}.

Generate a valid DuckDB SQL query to answer this. Always use read_parquet('{s3_path}') as the data source.
Do not add explanations; output only the SQL query.
Examples:
- Question: How many Starbucks Outlets are there in India?
  SQL: SELECT COUNT(*) as count FROM read_parquet('{s3_path}') WHERE LOWER(name) LIKE '%starbucks%' OR '%starbucks%' IN (SELECT LOWER(unnest(fsq_category_labels))));
"""

response_prompt = PromptTemplate.from_template(RESPONSE_TEMPLATE)

# Create the prompt template with the required input variables and partial variable for S3_PATH
sql_prompt = PromptTemplate(
    input_variables=["input", "top_k", "table_info"],
    template=SQL_PROMPT_TEMPLATE,
    partial_variables={"s3_path": DATA_PATH}
)

chain = sql_prompt | llm | StrOutputParser()

response = chain.invoke(
    {
        "input": "Show me the number of outlets in Gurgaon area?",
        "top_k": 10,
        "table_info": table_info
    }
)

response_chain = response_prompt | llm

response_chain

# Execute the generated SQL query
query_result = execute_sql(response, duckdb_conn)
print(query_result)


SELECT COUNT(*) as count 
FROM read_parquet('/content/output.geoparquet') 
WHERE LOWER(address) LIKE '%gurgaon%' 
LIMIT 10;

[(1097,)]


In [24]:
response = chain.invoke(
    {
        "input": "give me the number of outlets with category retail?",
        "top_k": 10,
        "table_info": table_info
    }
)

response = response.replace('```','').replace('sql','')
print(response)

# Execute the generated SQL query
query_result = execute_sql(response, duckdb_conn)
print(query_result)


SELECT COUNT(*) as count FROM read_parquet('/content/output.geoparquet') WHERE category_level_1 = 'Retail' LIMIT 10;

[(292123,)]
