<a href="https://colab.research.google.com/github/kavyajeetbora/foursquare_ai/blob/master/notebooks/07_duckdb_ai_bot.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Setup the Python Environment

In [None]:
!pip install --quiet duckdb jupysql duckdb-engine
## Langchain Framework
!pip install --quiet langchain langchain-community langchain-openai langgraph "langchain[openai]"

In [None]:
import duckdb
import os
import pandas as pd
from dotenv import load_dotenv
from google.colab import files

from langchain_openai import ChatOpenAI
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate

from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser

from pydantic import BaseModel, Field, field_validator
import json


# Check if .env file exists, if not, prompt for upload
if not os.path.exists(".env"):
    uploaded = files.upload()
    if ".env" not in uploaded:
        print("Please upload a .env file to proceed.")
    else:
        with open(".env", "wb") as f:
            f.write(uploaded[".env"])
        load_dotenv(dotenv_path = ".env", override=True)
else:
    load_dotenv(dotenv_path = ".env", override=True)

In [None]:
os.environ["LANGSMITH_TRACING"] = "true"
os.environ["LANGSMITH_ENDPOINT"] = "https://api.smith.langchain.com"
os.environ["LANGSMITH_PROJECT"] = "sql_agent_project"
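# Set the LangSmith key only when it was loaded from .env; assigning None to os.environ raises a TypeError
langsmith_api_key = os.getenv("LANGSMITH_API_KEY")
if langsmith_api_key:
    os.environ["LANGSMITH_API_KEY"] = langsmith_api_key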

## Data Engineering Part

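- Read the Foursquare OS Places release from S3 and keep only places in India that have coordinates
- Join each place to its level 1 and level 2 category names
- Export the result to a local Parquet file for faster loading and querying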

In [None]:
# Initialize DuckDB connection
con = duckdb.connect()

# Load required extensions
con.execute("INSTALL httpfs; LOAD httpfs; INSTALL spatial; LOAD spatial;")

s3_places_path = 's3://fsq-os-places-us-east-1/release/dt=2025-09-09/places/parquet/places-*.zstd.parquet'
s3_categories_path = 's3://fsq-os-places-us-east-1/release/dt=2025-09-09/categories/parquet/categories.zstd.parquet'

# Execute the SELECT query and create a view
con.execute(f"""
CREATE OR REPLACE VIEW places_with_categories AS
WITH places AS (
    SELECT
        DISTINCT UNNEST(P.fsq_category_ids) as fsq_category_id,
        name,
        postcode,
        address,
        region,
        ST_Point(longitude, latitude) AS geom
    FROM read_parquet('{s3_places_path}') AS P
    WHERE latitude IS NOT NULL AND longitude IS NOT NULL AND country='IN'
),
places_with_categories AS (
    SELECT
        P.name AS name,
        C.level1_category_name AS category_level_1,
        C.level2_category_name AS category_level_2,
        postcode,
        address,
        region,
        P.geom
    FROM places AS P
    JOIN read_parquet('{s3_categories_path}') AS C
    ON P.fsq_category_id = C.category_id
)
SELECT
    name,
    category_level_1,
    category_level_2,
    address,
    region,
    postcode,
    geom
FROM places_with_categories;
""")

# Export the view to GeoParquet
con.execute("COPY (SELECT * FROM places_with_categories) TO 'output.geoparquet' WITH (FORMAT PARQUET, CODEC ZSTD);")

## Check the total count of the database:
# con.execute("SELECT COUNT(*) FROM places_with_categories;")
# result = con.fetchone()[0]
# print(result)
## Around 1358392 points are there

# Close the connection
con.close()

In [None]:
## Example query: count the places related to manufacturing

example_query = '''
SELECT COUNT(*) as count FROM read_parquet('/content/output.geoparquet') WHERE (LOWER(name) LIKE '%manufacturing%' OR LOWER(category_level_1) LIKE '%manufacturing%' OR LOWER(category_level_2) LIKE '%manufacturing%' OR LOWER(address) LIKE '%manufacturing%') LIMIT 10;
'''
# example_query = '''
# SELECT name, category_level_1, category_level_2, address, region, postcode
# FROM read_parquet('/content/output.geoparquet')
# WHERE (LOWER(name) LIKE '%coffee%'
# OR LOWER(category_level_1) LIKE '%coffee%'
# OR LOWER(category_level_2) LIKE '%coffee%'
# OR LOWER(address) LIKE '%coffee%')
# AND (LOWER(region) LIKE '%bangalore%' OR LOWER(address) LIKE '%bangalore%')
# ORDER BY CASE
# WHEN LOWER(name) LIKE '%coffee%' THEN 1
# WHEN LOWER(category_level_2) LIKE '%coffee%' THEN 2
# WHEN LOWER(category_level_1) LIKE '%coffee%' THEN 3
# ELSE 4 END
# LIMIT 10;
# '''

con = duckdb.connect()
result = con.execute(example_query).df()
con.close()

result.head()

# AI Chatbot for DuckDB Queries: Concise Overview

- **Core Objective**: Build an AI chatbot that converts natural language questions into DuckDB SQL queries on Parquet data (e.g., POI datasets), executes them, and generates conversational responses.
- **Tech Stack**: OpenAI LLM (GPT-4o-mini) for SQL/answer generation; LangGraph for stateful workflow orchestration; DuckDB for efficient Parquet querying.
- **Pipeline Flow**:
    1. LLM generates SQL from question + schema;
    2. Execute query and format results as JSON;
    3. LLM synthesizes witty, concise answer from question, SQL, and results.
    
- **Key Implementation**: Pydantic `State` for propagation; Graph nodes for gen/execute/answer with linear edges; Auto-schema from Parquet (columns/types/samples) for prompts.
- **Usage & Next**: Invoke via `graph.invoke(State(question="..."))`; Extend with geospatial filters, multi-turn history, or error retries.
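
The three pipeline steps can be sketched in plain Python before any LLM or graph framework is involved. The sketch below is illustrative only: `PipelineState`, `write_query`, `execute_query`, `generate_answer`, and `run_pipeline` are hypothetical names, the query writer is a hard-coded stub standing in for the LLM, and the standard-library `sqlite3` module stands in for DuckDB so the example runs without extra installs. The two sample rows are made up.

```python
import json
import sqlite3
from dataclasses import dataclass


@dataclass
class PipelineState:
    """State passed from step to step (hypothetical stand-in for the notebook's State)."""
    question: str
    query: str = ""
    result: str = ""
    answer: str = ""


def write_query(state: PipelineState) -> PipelineState:
    # Stub for the LLM call: always returns the same keyword query.
    state.query = "SELECT name FROM places WHERE LOWER(name) LIKE '%coffee%' LIMIT 10;"
    return state


def execute_query(state: PipelineState, conn: sqlite3.Connection) -> PipelineState:
    # Run the SQL and store the rows as a JSON string of dicts.
    cursor = conn.execute(state.query)
    columns = [desc[0] for desc in cursor.description]
    rows = [dict(zip(columns, row)) for row in cursor.fetchall()]
    state.result = json.dumps(rows)
    return state


def generate_answer(state: PipelineState) -> PipelineState:
    # Stub for the second LLM call: summarize the row count.
    rows = json.loads(state.result)
    state.answer = f"Found {len(rows)} matching place(s)."
    return state


def run_pipeline(question: str) -> PipelineState:
    # sqlite3 is an assumption made for portability; the notebook queries Parquet with DuckDB.
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE places (name TEXT)")
    conn.executemany("INSERT INTO places VALUES (?)", [("Blue Coffee",), ("City Temple",)])
    state = PipelineState(question=question)
    state = write_query(state)
    state = execute_query(state, conn)
    state = generate_answer(state)
    conn.close()
    return state


final_state = run_pipeline("List some coffee shops")
print(final_state.answer)
```

Keeping each step as a function that takes and returns the state makes it straightforward to move the same steps into graph nodes later.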

## Setup Input and Output Structured Format


In [None]:
class State(BaseModel):
    question: str = Field(
        default="",
        description="The natural language question or input provided by the user, intended to be translated into a DuckDB SQL query. Example: 'Show me all customers with orders above $100.'"
    )
    data_path: str = Field(
        default="",
        description="This is the data source of the table from where we will be querying using duckdb. This is a parquet file"
    )
    query: str = Field(
        default="",
        description="The generated DuckDB SQL query derived from the user's natural language question. Must conform to DuckDB's SQL syntax and conventions. Example: 'SELECT * FROM customers WHERE order_total > 100;'"
    )
    result: str = Field(
        default="",
        description="The output or result set returned by executing the DuckDB SQL query against the database, formatted as a JSON string. Example: '[{\"id\": 1, \"name\": \"Alice\", \"order_total\": 150}]'"
    )
    answer: str = Field(
        default="",
        description="The final natural language response generated for the user, summarizing or explaining the query results in a conversational manner. Example: 'Here are the customers with orders above $100: Alice with $150.'"
    )

    @field_validator("result", mode="after")
    @classmethod
    def validate_result(cls, result: str) -> str:
        try:
            json.loads(result)
        except json.JSONDecodeError:
            raise ValueError("Result must be a valid JSON string")
        return result

class QueryOutput(BaseModel):
    query: str = Field(..., description="Syntactically valid DuckDB SQL query")

    @field_validator("query", mode="after")
    @classmethod
    def validate_query(cls, query: str) -> str:
        if not query.strip().endswith(";"):
            raise ValueError("DuckDB query must end with a semicolon")
        return query


## Helper functions

- Set up the connection to DuckDB
- Generate schema details (column names, data types, and sample values) from the Parquet file
- Define the input and output structures for the parameters passed to the LLM
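
As a rough illustration of the schema step, the helper below turns column names, data types, and a few sample rows into the kind of text block that can be dropped into a prompt. `format_schema` is a hypothetical function written for this example, and the sample rows are invented; in the notebook the types and samples come from DuckDB.

```python
def format_schema(columns, data_types, sample_rows):
    """Build a prompt-friendly schema description.

    columns:     list of column names
    data_types:  list of type names, in the same order as columns
    sample_rows: list of tuples, each holding one value per column
    """
    lines = ["Columns:"]
    for i, (column, data_type) in enumerate(zip(columns, data_types)):
        # Collect the i-th value of every sample row for this column.
        samples = ", ".join(str(row[i]) for row in sample_rows)
        lines.append(
            f"{i + 1}. Name: {column} | Data Type: {data_type} | Sample values: {samples}"
        )
    return "\n".join(lines)


# Invented sample data, for illustration only.
schema_text = format_schema(
    columns=["name", "region"],
    data_types=["VARCHAR", "VARCHAR"],
    sample_rows=[("Cafe A", "Goa"), ("Hotel B", "Delhi")],
)
print(schema_text)
```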

Set up the system prompt template for the AI chatbot:

In [None]:
# Prompt setup
system_message = """
Given an input question, create a syntactically correct DuckDB SQL query to answer it using data from Parquet files at {data_path}.
Limit results to {top_k} unless the user specifies a different number, ordering by a relevant column for the most interesting results.
Use only columns: `name`, `category_level_1`, `category_level_2`, `address`, `region`, `postcode`.
For keyword searches (e.g., "temple," "hotel"),
search across `name`, `category_level_1`, `category_level_2`, and `address` using LOWER() and LIKE, prioritizing matches in `name` and `category_level_2` via ORDER BY CASE.
For region-specific queries, check `region` and `address` columns.

Restrict outlets (POIs) to India only.
If a non-India country is mentioned, respond with "Query beyond scope, restricted to India" and do not generate a query.

Use read_parquet('{data_path}') as the data source.

Schema: {table_info}
User question: {input}
Limit: {top_k}
Output only the Duckdb query without explanations.
Strictly stick to the above mentioned schema only.

Few Shot Examples:
- Question: How many temples are there in India?
- Query: SELECT COUNT(*) as count FROM read_parquet('{data_path}') WHERE (LOWER(name) LIKE '%temple%' OR LOWER(category_level_1) LIKE '%temple%' OR LOWER(category_level_2) LIKE '%temple%' OR LOWER(address) LIKE '%temple%') LIMIT {top_k};

- Question: Which hotels are located in Goa?
- Query: SELECT name, category_level_1, category_level_2, address, region, postcode FROM read_parquet('{data_path}') WHERE (LOWER(name) LIKE '%hotel%' OR LOWER(category_level_1) LIKE '%hotel%' OR LOWER(category_level_2) LIKE '%hotel%' OR LOWER(address) LIKE '%hotel%') AND (LOWER(region) LIKE '%goa%' OR LOWER(address) LIKE '%goa%') ORDER BY CASE WHEN LOWER(name) LIKE '%hotel%' THEN 1 WHEN LOWER(category_level_2) LIKE '%hotel%' THEN 2 WHEN LOWER(category_level_1) LIKE '%hotel%' THEN 3 ELSE 4 END LIMIT {top_k};

- Question: List petrol pumps in Chennai.
- Query: SELECT name, category_level_1, category_level_2, address, region, postcode FROM read_parquet('{data_path}') WHERE (LOWER(name) LIKE '%petrol%' OR LOWER(category_level_1) LIKE '%petrol%' OR LOWER(category_level_2) LIKE '%petrol%' OR LOWER(address) LIKE '%petrol%') AND (LOWER(region) LIKE '%chennai%' OR LOWER(address) LIKE '%chennai%') ORDER BY CASE WHEN LOWER(name) LIKE '%petrol%' THEN 1 WHEN LOWER(category_level_2) LIKE '%petrol%' THEN 2 WHEN LOWER(category_level_1) LIKE '%petrol%' THEN 3 ELSE 4 END LIMIT {top_k};

- Question: How many restaurants are in Bangalore?
- Query: SELECT COUNT(*) as count FROM read_parquet('{data_path}') WHERE (LOWER(name) LIKE '%restaurant%' OR LOWER(category_level_1) LIKE '%restaurant%' OR LOWER(category_level_2) LIKE '%restaurant%' OR LOWER(address) LIKE '%restaurant%') AND (LOWER(region) LIKE '%bangalore%' OR LOWER(address) LIKE '%bangalore%') LIMIT {top_k};

- Question: Find bookstores in Delhi.
- Query: SELECT name, category_level_1, category_level_2, address, region, postcode FROM read_parquet('{data_path}') WHERE (LOWER(name) LIKE '%bookstore%' OR LOWER(category_level_1) LIKE '%bookstore%' OR LOWER(category_level_2) LIKE '%bookstore%' OR LOWER(address) LIKE '%bookstore%') AND (LOWER(region) LIKE '%delhi%' OR LOWER(address) LIKE '%delhi%') ORDER BY CASE WHEN LOWER(name) LIKE '%bookstore%' THEN 1 WHEN LOWER(category_level_2) LIKE '%bookstore%' THEN 2 WHEN LOWER(category_level_1) LIKE '%bookstore%' THEN 3 ELSE 4 END LIMIT {top_k};
"""
user_prompt = "Question: {input}"
query_prompt_template = ChatPromptTemplate.from_messages([("system", system_message), ("user", user_prompt)])
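
The query shape that the prompt asks the model to produce can also be written out by hand, which is handy for checking generated SQL against a known-good reference. The helper below is a sketch, not part of the notebook: `build_keyword_query` and its parameters are hypothetical, and the quote escaping is a minimal precaution rather than a full defense against SQL injection.

```python
def build_keyword_query(keyword, data_path, region=None, top_k=10):
    """Build a keyword search query in the style of the few-shot examples."""
    # Lowercase the keyword and double any single quotes so the literal stays valid SQL.
    kw = keyword.lower().replace("'", "''")
    search_columns = ["name", "category_level_1", "category_level_2", "address"]
    keyword_filter = " OR ".join(f"LOWER({col}) LIKE '%{kw}%'" for col in search_columns)
    where = f"({keyword_filter})"
    if region:
        rg = region.lower().replace("'", "''")
        where += f" AND (LOWER(region) LIKE '%{rg}%' OR LOWER(address) LIKE '%{rg}%')"
    # Rank matches in name first, then category_level_2, then category_level_1.
    return (
        "SELECT name, category_level_1, category_level_2, address, region, postcode "
        f"FROM read_parquet('{data_path}') WHERE {where} "
        f"ORDER BY CASE WHEN LOWER(name) LIKE '%{kw}%' THEN 1 "
        f"WHEN LOWER(category_level_2) LIKE '%{kw}%' THEN 2 "
        f"WHEN LOWER(category_level_1) LIKE '%{kw}%' THEN 3 ELSE 4 END "
        f"LIMIT {int(top_k)};"
    )


reference_query = build_keyword_query("Hotel", "/content/output.geoparquet", region="Goa")
print(reference_query)
```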

Create a ChatBot Class

In [None]:
class FourSquareChatBot:
    """NLP-to-SQL chatbot for querying DuckDB databases, optimized for Parquet files and FourSquare data."""

    def __init__(self, data_path: str, columns: list[str], llm, database: str = ":memory:"):
        """
        Initialize the chatbot with a DuckDB connection and schema.

        Args:
            data_path (str): Path to the Parquet file (e.g., S3 URL or local path).
            columns (list[str]): List of column names to include in the schema.
            llm: Language model instance for generating SQL queries and answers.
            database (str): DuckDB database path (default: ':memory:' for in-memory).
        """
        self.data_path = data_path
        self.columns = columns
        self.llm = llm
        self.conn = self._get_duckdb_connection(database)
        self.table_info = self._get_db_schema(limit=5)

    def _get_duckdb_connection(self, database: str) -> duckdb.DuckDBPyConnection:
        """Create and configure a DuckDB connection."""
        con = duckdb.connect(database=database)
        con.execute("INSTALL httpfs;")
        con.execute("LOAD httpfs;")
        con.execute("INSTALL spatial;")
        con.execute("LOAD spatial;")
        # Optional: Set S3 region if needed
        # con.execute("SET s3_region='us-east-1';")
        return con

    def _get_db_schema(self, limit=5):

        """Generate schema information from the Parquet file with sample values."""

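        data_schema = "Columns:\n"

        # Join the column names first: nested double quotes inside an f-string are a syntax error before Python 3.12
        column_list = ", ".join(self.columns)
        sql_query = f"SELECT {column_list} FROM read_parquet('{self.data_path}') WHERE 1=1"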

        for column in self.columns:
            sql_query += f" AND {column} IS NOT NULL"

        sql_query += f" LIMIT {limit};"
        sample_result = self._execute_sql(sql_query)['result']

        schema_details = self._execute_sql(f'DESCRIBE {sql_query}')['result']

        for i, column in enumerate(self.columns):
            data_type = schema_details[i][1]
            sample_values = ",".join([str(r[i]) for r in sample_result])
            data_schema += f"{i+1}. Name: {column} | Data Type: {data_type} | Sample values: {sample_values}"
            data_schema += "\n"
        return data_schema

    def _execute_sql(self, sql_query: str) -> dict | str:
        """Execute a SQL query; return a dict holding the raw rows, or an error string on failure."""
        try:
            # fetchall() returns a list of tuples, one per row
            result = self.conn.execute(sql_query).fetchall()
            return {"result": result, "error": None}
        except Exception as e:
            # Callers detect failure by checking for a string that starts with "Error"
            return f"Error executing SQL: {str(e)}"

    def generate_sql_query(self, state: State) -> QueryOutput:
        """Generate a DuckDB SQL query from the user's question."""
        prompt = query_prompt_template.invoke(
            {
                "top_k": 10,
                "table_info": self.table_info,
                "input": state.question,
                "data_path": self.data_path
            }
        )
        response = self.llm.invoke(prompt)
        cleaned_query = response.content.strip()

        if cleaned_query.startswith("```sql"):
            cleaned_query = cleaned_query[6:]
        if cleaned_query.endswith("```"):
            cleaned_query = cleaned_query[:-3]
        # if not cleaned_query.endswith(";"):
        #      cleaned_query += ";"
        return QueryOutput(query=cleaned_query.strip())

    def generate_answer(self, state: State) -> dict:
        """Generate a conversational answer using query results."""
        # Execute the query and store JSON result

        result = self._execute_sql(state.query)

        if isinstance(result, str) and result.startswith("Error"):
            state.result = json.dumps({"error": result})
            state.answer = f"Sorry, I couldn't process your query due to an error: {result}"
        else:
            # Store the rows as a JSON string, as the State.result field expects
            state.result = json.dumps(result['result'], default=str)

            # Generate conversational answer
            prompt = (
                "Given the following user question, corresponding SQL query, "
                "and SQL result, answer the user question in a conversational manner.\n\n"
                f"Question: {state.question}\n"
                f"SQL Query: {state.query}\n"
                f"SQL Result: {state.result}"
            )
            response = self.llm.invoke(prompt)
            state.answer = response.content

        return {"state": state}

    def process_question(self, question: str) -> dict:
        """Process a user question end-to-end and return the updated state."""
        state = State(question=question)
        query_output = self.generate_sql_query(state)
        state.query = query_output.query
        result = self.generate_answer(state)
        return result

    def __del__(self):
        """Clean up by closing the DuckDB connection."""
        if hasattr(self, "conn"):
            self.conn.close()

The pipeline runs in three steps (in the class above, `generate_answer` both executes the query and writes the reply):

write_query -> execute_query -> generate_answer
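
Between the first two steps, the model's reply usually needs a little cleanup before it is safe to execute. The sketch below shows one way to do that; `clean_llm_sql` is a hypothetical helper, not the notebook's code. It strips a Markdown code fence, appends the trailing semicolon that `QueryOutput` requires, and returns `None` when the reply is not a query at all, for example the "Query beyond scope" message the prompt requests for non-India questions.

```python
import re
from typing import Optional


def clean_llm_sql(reply: str) -> Optional[str]:
    """Return an executable SQL string, or None if the reply is not a query."""
    text = reply.strip()
    # Drop a leading ```sql (or bare ```) fence and a trailing ``` fence, if present.
    text = re.sub(r"^```(?:sql)?\s*", "", text, flags=re.IGNORECASE)
    text = re.sub(r"\s*```$", "", text)
    text = text.strip()
    # Anything that does not start like a read-only query is treated as a refusal or chatter.
    if not text.upper().startswith(("SELECT", "WITH")):
        return None
    if not text.endswith(";"):
        text += ";"
    return text


print(clean_llm_sql("```sql\nSELECT 1\n```"))
print(clean_llm_sql("Query beyond scope, restricted to India"))
```

Returning `None` lets the caller answer the user directly instead of passing a non-query string to the validator or to DuckDB.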

In [239]:
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)

fsq_chat_bot = FourSquareChatBot(
    data_path = "/content/output.geoparquet",
    columns = ['name', 'category_level_1', 'category_level_2', 'address', 'region', 'postcode'],
    llm = llm
)

## Process Query

# query = "Tell me the count for all the points that are related to manufacturing?"
query = "List down some coffee shops in Jodhpur?"
print(fsq_chat_bot.process_question(question=query)['state'].answer)

Sure! Here are some coffee shops you can check out in Jodhpur:

1. **Café Coffee Day**
   - **Category:** Cafe, Coffee, and Tea House
   - **Address:** OLD DHANIA HOUSE - JODHPUR, RAJASTHAN

2. **Abar Boithak - The Coffee Shop**
   - **Category:** Cafe, Coffee, and Tea House
   - **Address:** 282, Jodhpur Pk, West Bengal, 700068

3. **Café Coffee Day**
   - **Category:** Cafe, Coffee, and Tea House
   - **Address:** ANSAL PLAZA – JODHPUR, RAJASTHAN

4. **Blue Mug Coffee and Thoughts**
   - **Category:** Cafe, Coffee, and Tea House
   - **Address:** 1/12, Jodhpur Park Road, Dhakuria, Selimpur, 1/121 Jodhpur Park, West Bengal, 700068

5. **Cha Coffee Jalkhabar**
   - **Category:** Cafe, Coffee, and Tea House
   - **Address:** Jodhpur Pk, West Bengal, 700068

6. **Blue Tokai Coffee Roasters**
   - **Category:** Cafe, Coffee, and Tea House
   - **Address:** 26, Jodhpur Pk, West Bengal, 700068

7. **Rana Snacketeria**
   - **Category:** Cafe, Coffee, and Tea House
   - **Address:** 394A, Jo