<a href="https://colab.research.google.com/github/kavyajeetbora/foursquare_ai/blob/master/notebooks/07_duckdb_ai_bot.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Setup the Python Environment

In [None]:
# DuckDB and the packages used to run SQL against it from the notebook
!pip install --quiet duckdb jupysql duckdb-engine
# LangChain framework, its community/OpenAI integrations, and LangGraph
!pip install --quiet langchain langchain-community langchain-openai langgraph "langchain[openai]"

In [None]:
import duckdb
import os
import pandas as pd
from dotenv import load_dotenv
from google.colab import files

from langchain_openai import ChatOpenAI
from langchain_core.prompts import PromptTemplate

from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser

# Ensure a .env file is available, prompting for a Colab upload when it is
# missing, then load its variables into the process environment.
if not os.path.exists(".env"):
    uploaded = files.upload()
    if ".env" in uploaded:
        # Persist the uploaded bytes under the expected file name.
        with open(".env", "wb") as f:
            f.write(uploaded[".env"])

# Single load point: reached both when the file was already present and when
# it has just been written from the upload.
if os.path.exists(".env"):
    load_dotenv(dotenv_path=".env", override=True)
else:
    print("Please upload a .env file to proceed.")

In [None]:
# Configure LangSmith tracing for this session.
os.environ["LANGSMITH_ENDPOINT"] = "https://api.smith.langchain.com"
os.environ["LANGSMITH_PROJECT"] = "sql_agent_project"

# os.environ only accepts strings, so assigning os.getenv(...) straight back
# raises TypeError when the key is absent (e.g. the .env file lacks it).
# When the key is present it is already in the environment, so it only needs
# to decide whether tracing is switched on.
if os.getenv("LANGSMITH_API_KEY"):
    os.environ["LANGSMITH_TRACING"] = "true"
else:
    os.environ["LANGSMITH_TRACING"] = "false"
    print("LANGSMITH_API_KEY is not set; LangSmith tracing is disabled.")

## Data Engineering Part

- Here we will prepare the data for our analysis
- Clean Data
- Load it to local storage for faster loading and querying

In [None]:
# Build a local Parquet extract of Foursquare places in India (country='IN')
# joined to their category names, so later cells can query a local file
# instead of the S3 release.

# Initialize DuckDB connection
con = duckdb.connect()

# try/finally guarantees the connection is closed even when an extension
# install, the S3 read, or the export fails part-way through.
try:
    # Load required extensions: httpfs for the s3:// paths, spatial for ST_Point
    con.execute("INSTALL httpfs; LOAD httpfs; INSTALL spatial; LOAD spatial;")

    s3_places_path = 's3://fsq-os-places-us-east-1/release/dt=2025-09-09/places/parquet/places-*.zstd.parquet'
    s3_categories_path = 's3://fsq-os-places-us-east-1/release/dt=2025-09-09/categories/parquet/categories.zstd.parquet'

    # Create a view with one row per (place, category): fsq_category_ids is
    # unnested, then joined to the categories table for the level 1/2 names.
    con.execute(f"""
CREATE OR REPLACE VIEW places_with_categories AS
WITH places AS (
    SELECT
        DISTINCT UNNEST(P.fsq_category_ids) as fsq_category_id,
        name,
        postcode,
        address,
        region,
        ST_Point(longitude, latitude) AS geom
    FROM read_parquet('{s3_places_path}') AS P
    WHERE latitude IS NOT NULL AND longitude IS NOT NULL AND country='IN'
),
places_with_categories AS (
    SELECT
        P.name AS name,
        C.level1_category_name AS category_level_1,
        C.level2_category_name AS category_level_2,
        postcode,
        address,
        region,
        P.geom
    FROM places AS P
    JOIN read_parquet('{s3_categories_path}') AS C
    ON P.fsq_category_id = C.category_id
)
SELECT
    name,
    category_level_1,
    category_level_2,
    address,
    region,
    postcode,
    geom
FROM places_with_categories;
""")

    # Export the view to GeoParquet
    con.execute("COPY (SELECT * FROM places_with_categories) TO 'output.geoparquet' WITH (FORMAT PARQUET, CODEC ZSTD);")

    # NOTE: `SELECT COUNT(*) FROM places_with_categories;` returned around
    # 1358392 rows when this notebook was written.
finally:
    # Close the connection
    con.close()

# AI ChatBot

- Turns a natural-language question into a DuckDB SQL query over the exported Parquet file

## Helper functions

- Set up a connection to DuckDB
- Generate schema details (column names, data types, sample values) for the data file

In [None]:
def execute_sql(sql_query: str, con):
    """Run ``sql_query`` on ``con`` and return all rows it produces.

    Args:
        sql_query: SQL text passed to ``con.execute`` unchanged.
        con: Connection object whose ``execute(...)`` result offers
            ``fetchall()``.

    Returns:
        The ``fetchall()`` result on success. If anything raises, no
        exception escapes; a string of the form
        ``"Error executing SQL: <message>"`` is returned instead.
    """
    try:
        cursor = con.execute(sql_query)
        rows = cursor.fetchall()
    except Exception as exc:
        # Failures are reported as text rather than raised.
        return f"Error executing SQL: {exc!s}"
    else:
        return rows

def get_duckdb_connection():
    """Open an in-memory DuckDB connection with the httpfs extension loaded.

    Returns:
        The connection returned by ``duckdb.connect(database=':memory:')``
        after ``INSTALL httpfs;`` and ``LOAD httpfs;`` have been executed on
        it. Pass a file path to ``duckdb.connect`` instead if persistence is
        needed.
    """
    connection = duckdb.connect(database=':memory:')
    for statement in ("INSTALL httpfs;", "LOAD httpfs;"):
        connection.execute(statement)
    # Optional: an S3 region can be set here if needed, e.g.
    # connection.execute("SET s3_region='us-east-1';")
    return connection

def get_db_schema(DATA_PATH, columns, duckdb_conn, limit=5):
    """Describe the selected Parquet columns as text for an LLM prompt.

    Reads up to ``limit`` rows in which every requested column is non-NULL,
    asks DuckDB to ``DESCRIBE`` that same query to obtain the column types,
    and renders one line per column.

    Args:
        DATA_PATH: Path of the Parquet file handed to ``read_parquet``.
        columns: Non-empty sequence of column names. They are interpolated
            into the SQL text as-is, so they must come from trusted code.
        duckdb_conn: Connection whose ``execute(sql)`` result offers
            ``fetchall()``.
        limit: Maximum number of sample rows to read.

    Returns:
        A string that starts with ``"Columns:\\n"`` followed by one
        newline-terminated line per column of the form
        ``"<n>. Name: <col> | Data Type: <type> | Sample values: <v1>,<v2>"``.

    Raises:
        ValueError: If ``columns`` is empty.
        Exception: Whatever ``duckdb_conn.execute`` raises for a failing
            query. The connection is queried directly so that a failure
            surfaces here instead of an error string being parsed as rows.
    """
    if not columns:
        raise ValueError("columns must contain at least one column name")

    # Built outside the f-string: reusing the enclosing quote character inside
    # a replacement field is a SyntaxError before Python 3.12.
    column_list = ",".join(columns)
    not_null_filter = "".join(f" AND {column} IS NOT NULL" for column in columns)
    # Double any single quote so the path stays a valid SQL string literal.
    escaped_path = str(DATA_PATH).replace("'", "''")

    sql_query = (
        f"SELECT {column_list} FROM read_parquet('{escaped_path}') WHERE 1=1"
        f"{not_null_filter} LIMIT {limit};"
    )

    sample_result = duckdb_conn.execute(sql_query).fetchall()
    # Each DESCRIBE row describes one selected column, in SELECT order; the
    # second field of the row is the column's data type.
    schema_details = duckdb_conn.execute(f"DESCRIBE {sql_query}").fetchall()

    lines = ["Columns:"]
    for i, column in enumerate(columns):
        data_type = schema_details[i][1]
        sample_values = ",".join(str(row[i]) for row in sample_result)
        lines.append(
            f"{i + 1}. Name: {column} | Data Type: {data_type} | Sample values: {sample_values}"
        )
    return "\n".join(lines) + "\n"

In [None]:
# Parquet file to inspect and the columns whose schema should be summarised.
DATA_PATH = "/content/output.geoparquet"
columns = [
    'name',
    'category_level_1',
    'category_level_2',
    'address',
    'region',
    'postcode',
]

duckdb_conn = get_duckdb_connection()

# Build the schema summary from up to 10 sample rows and show it.
table_info = get_db_schema(DATA_PATH, columns, duckdb_conn, limit=10)
print(table_info)

## Initialize the chatbot

In [None]:
# Initialize LLM (the OpenAI API key is read from the environment)
# NOTE(review): temperature=1 gives varied output; a lower value is usually
# preferred for SQL generation - confirm before wiring the model in.
llm = ChatOpenAI(model="gpt-4o-mini", temperature=1)  # Or ChatGroq(model="grok-beta", api_key="YOUR_XAI_API_KEY")

# Pre-prompt for NL-to-SQL (customize this!)
# The wording must match the data file actually queried: it holds only the
# columns listed in {table_info} and has no country column, so the prompt must
# not ask for a `country = 'IN'` filter or reference columns that are absent.
# {s3_path} is filled in as a partial variable below.
SQL_PROMPT_TEMPLATE = """
You are a SQL expert for querying Foursquare POI data using DuckDB. The data is in a Parquet file at {s3_path}.
The file contains only places located in India; it has no country column, so never filter on country.

Based on the schema below and the user's question, generate a valid DuckDB SQL query.
Schema:
{table_info}

Constraints:
Use only the columns listed in the schema above.
All outlets (POIs) in the data are within India. If the user asks about any other country, reply that the request is beyond scope and do not generate a SQL query.

User question: {input}
Limit the results to {top_k}.

Generate a valid DuckDB SQL query to answer this. Always use read_parquet('{s3_path}') as the data source.
Do not add explanations; output only the SQL query.
Examples:
- Question: How many Starbucks Outlets are there in India?
  SQL: SELECT COUNT(*) AS count FROM read_parquet('{s3_path}') WHERE LOWER(name) LIKE '%starbucks%';
"""

# Create the prompt template with the required input variables and a partial
# variable that pins {s3_path} to the data file path.
sql_prompt = PromptTemplate(
    input_variables=["input", "top_k", "table_info"],
    template=SQL_PROMPT_TEMPLATE,
    partial_variables={"s3_path": DATA_PATH}
)

# NOTE: the chain is currently just the prompt template, so invoking it only
# renders the prompt text; `llm` is not called yet. This is handy for checking
# the final prompt before a model is attached to the chain.
chain = sql_prompt

response = chain.invoke(
    {
        "input": "Show me the number of outlets in Gurgaon area?",
        "top_k": 10,
        "table_info": table_info
    }
)

In [None]:
# Display the value returned by chain.invoke() as the cell output.
response