<a href="https://colab.research.google.com/github/bhavyaJ-05/Text-to-SQL-generator/blob/main/15_Talk_to_Your_Data_Building_a_Natural_Language_to_SQL_Generator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Text2SQL via Prompt Engineering

## Retrieve data

In [None]:
# Download 1,000 mock customer rows from Mockaroo into customers.csv.
# -f: fail on HTTP errors instead of saving the error body as the CSV; -sS: hide the progress meter but still print errors.
# NOTE(review): the Mockaroo API key is embedded in the URL; consider loading it from a Colab secret instead.
! curl -fsS "https://api.mockaroo.com/api/dde01370?count=1000&key=11149690" -o "customers.csv"

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 99680    0 99680    0     0  61092      0 --:--:--  0:00:01 --:--:-- 61115


In [None]:
# Download 1,000 mock product rows from Mockaroo into products.csv.
# -f: fail on HTTP errors instead of saving the error body as the CSV; -sS: hide the progress meter but still print errors.
# NOTE(review): the Mockaroo API key is embedded in the URL; consider loading it from a Colab secret instead.
! curl -fsS "https://api.mockaroo.com/api/8ba6f630?count=1000&key=11149690" -o "products.csv"

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  689k    0  689k    0     0   331k      0 --:--:--  0:00:02 --:--:--  331k


In [None]:
# Download 3,000 mock order rows from Mockaroo into orders.csv.
# -f: fail on HTTP errors instead of saving the error body as the CSV; -sS: hide the progress meter but still print errors.
# NOTE(review): the Mockaroo API key is embedded in the URL; consider loading it from a Colab secret instead.
! curl -fsS "https://api.mockaroo.com/api/6fa67fe0?count=3000&key=11149690" -o "orders.csv"

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  229k    0  229k    0     0  29509      0 --:--:--  0:00:07 --:--:-- 43363


## Setup database

In [None]:
import os
import re
import sqlite3
from pathlib import Path

import pandas as pd

In [None]:
# Define SQL schemas for creating tables
# Each string below is a CREATE TABLE statement that the database-setup cell
# executes before loading the CSV files. `IF NOT EXISTS` means the statement is
# a no-op when the table is already present.
# The same DDL is repeated inside the LLM prompt further down; keep both copies
# in sync when changing a table.
# NOTE(review): SQLite does not enforce VARCHAR(n) lengths, and the FOREIGN KEY
# clauses in `orders_schema` are only enforced when `PRAGMA foreign_keys = ON`
# is set on the connection (this notebook does not set it).

# Customer records; `customer_id` is referenced by orders.customer_id.
customers_schema = """
CREATE TABLE IF NOT EXISTS customers (
    customer_id INT PRIMARY KEY,
    first_name VARCHAR(50),
    last_name VARCHAR(50),
    email VARCHAR(50),
    phone_number VARCHAR(50),
    address VARCHAR(50),
    city VARCHAR(50),
    country VARCHAR(50),
    postal_code VARCHAR(50),
    loyalty_points INT
);
"""

# Product catalogue, including seller and shipping attributes;
# `product_id` is referenced by orders.product_id.
products_schema = """
CREATE TABLE IF NOT EXISTS products (
    product_id INT PRIMARY KEY,
    product_name TEXT,
    description TEXT,
    price DECIMAL(10,2),
    discount_percentage DECIMAL(5,2),
    category VARCHAR(50),
    brand TEXT,
    stock_quantity INT,
    color VARCHAR(50),
    size VARCHAR(20),
    weight DECIMAL(5,2),
    dimensions TEXT,
    release_date DATE,
    rating DECIMAL(3,1),
    reviews_count INT,
    seller_name TEXT,
    seller_rating DECIMAL(3,1),
    seller_reviews_count INT,
    shipping_method VARCHAR(20),
    shipping_cost DECIMAL(6,2)
);
"""

# One row per order line: links a customer to a product with quantity and prices.
orders_schema = """
CREATE TABLE IF NOT EXISTS orders (
    order_id INT PRIMARY KEY,
    customer_id INT,
    product_id INT,
    quantity INT,
    unit_price DECIMAL(10,2),
    total_price DECIMAL(10,2),
    order_date DATE,
    shipping_address VARCHAR(255),
    payment_method VARCHAR(20),
    status VARCHAR(20),
    FOREIGN KEY (customer_id) REFERENCES customers(customer_id),
    FOREIGN KEY (product_id) REFERENCES products(product_id)
);
"""

In [None]:
db_name = 'ecommerce.db'

# Delete the database file left by a previous run so the setup cell below loads
# the CSVs into empty tables instead of appending rows with duplicate primary keys.
if os.path.exists(db_name):
    os.remove(db_name)
    print(f"Removed existing database '{db_name}'.")

In [None]:
import sqlite3
import pandas as pd
import os



COLUMN_DATA_TYPES = {
    'customers': {
        'customer_id': 'int64',
        'first_name': 'object',
        'last_name': 'object',
        'email': 'object',
        'phone_number': 'object',
        'address': 'object',
        'city': 'object',
        'country': 'object',
        'postal_code': 'object',
        'loyalty_points': 'int64'
    },
    'products': {
        'product_id': 'int64',
        'product_name': 'object',
        'description': 'object',
        'price': 'float64',
        'discount_percentage': 'float64',
        'category': 'object',
        'brand': 'object',
        'stock_quantity': 'int64',
        'color': 'object',
        'size': 'object',
        'weight': 'float64',
        'dimensions': 'object',
        'release_date': 'datetime64[ns]',
        'rating': 'float64',
        'reviews_count': 'int64',
        'seller_name': 'object',
        'seller_rating': 'float64',
        'seller_reviews_count': 'int64',
        'shipping_method': 'object',
        'shipping_cost': 'float64'
    },
    'orders': {
        'order_id': 'int64',
        'customer_id': 'int64',
        'product_id': 'int64',
        'quantity': 'int64',
        'unit_price': 'float64',
        'total_price': 'float64',
        'order_date': 'datetime64[ns]',
        'shipping_address': 'object',
        'payment_method': 'object',
        'status': 'object'
    }
}

# --- Database setup ---
db_name = 'ecommerce.db'
conn = None  # Initialize connection to None

try:
    # Establish a connection to the SQLite database
    conn = sqlite3.connect(db_name)
    cursor = conn.cursor()
    print(f"Database '{db_name}' created and connected successfully. âœ…")

    # Create tables
    cursor.execute(customers_schema)
    cursor.execute(products_schema)
    cursor.execute(orders_schema)
    print("Tables 'customers', 'products', and 'orders' created successfully.")

    # --- Load data from CSV files into the tables using pandas ---
    csv_to_table_map = {
        '/content/customers.csv': 'customers',
        '/content/products.csv': 'products',
        '/content/orders.csv': 'orders'
    }

    for csv_file, table_name in csv_to_table_map.items():
        if os.path.exists(csv_file):
            print(f"\nProcessing '{csv_file}' for table '{table_name}'...")

            # Read the CSV file into a pandas DataFrame
            df = pd.read_csv(csv_file)

            # 1. Get the expected schema for the current table
            expected_schema = COLUMN_DATA_TYPES[table_name]
            expected_cols = list(expected_schema.keys())

            # 2. Handle missing/extra columns
            # Drop columns from DataFrame that are not in the schema
            df = df[df.columns.intersection(expected_cols)]

            # Add any missing columns and fill with None (which becomes NULL in SQL)
            for col in expected_cols:
                if col not in df.columns:
                    df[col] = None

            # 3. Reorder columns to match the defined schema exactly
            df = df[expected_cols]

            # 4. Enforce data types
            for col, dtype in expected_schema.items():
                if 'datetime' in dtype:
                    # Use pd.to_datetime for date/time columns, coercing errors to NaT (Not a Time)
                    df[col] = pd.to_datetime(df[col], errors='coerce')
                else:
                    # Use astype for other columns, handling potential conversion errors
                    try:
                        df[col] = df[col].astype(dtype)
                    except (ValueError, TypeError) as e:
                        print(f"  - Warning: Could not convert column '{col}' to {dtype}. Error: {e}. Leaving as is.")


            # Use the to_sql method to insert the cleaned DataFrame
            df.to_sql(table_name, conn, if_exists='append', index=False)
            print(f"  -> Data from '{csv_file}' loaded into '{table_name}' table successfully.")
        else:
            print(f"Warning: '{csv_file}' not found. Skipping data load for '{table_name}'.")

    # Commit the changes to the database
    conn.commit()
    print("\nData committed to the database successfully. ðŸŽ‰")

except sqlite3.Error as e:
    print(f"Database error: {e}")
except pd.errors.EmptyDataError as e:
    print(f"Pandas error: {e}. One of the CSV files might be empty.")
except KeyError as e:
    print(f"Schema definition error: A column is missing from the TABLE_DATA_TYPES dictionary: {e}")
except Exception as e:
    print(f"An unexpected error occurred: {e}")
finally:
    # Close the connection if it was established
    if conn:
        conn.close()
        print("Database connection closed.")

Database 'ecommerce.db' created and connected successfully. ✅
Tables 'customers', 'products', and 'orders' created successfully.

Processing '/content/customers.csv' for table 'customers'...
  -> Data from '/content/customers.csv' loaded into 'customers' table successfully.

Processing '/content/products.csv' for table 'products'...
  -> Data from '/content/products.csv' loaded into 'products' table successfully.

Processing '/content/orders.csv' for table 'orders'...
  -> Data from '/content/orders.csv' loaded into 'orders' table successfully.

Data committed to the database successfully. 🎉
Database connection closed.


# Set up your free API key using Google AI Studio

https://aistudio.google.com/


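Then add the key to Colab's **Secrets** panel under the name `GOOGLE_API_KEY`, which is the name the notebook reads when creating the Gemini client.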

### Install Gen AI library

We will install the `google-genai` package, Google's official Python SDK for the Gemini API.

In [None]:
# %pip (not !pip) installs into the environment of the running kernel; -q keeps the log short.
%pip install -q google-genai



### Import required modules

In [None]:
from google import genai
from google.colab import userdata

In [None]:
# Build the Gemini client with the API key stored in Colab Secrets as GOOGLE_API_KEY.
genai_client = genai.Client(
    api_key=userdata.get('GOOGLE_API_KEY'),
)

## PROMPT ENGINEERING


A key idea in prompt engineering is that a prompt is not a single monolithic question but a structured document made of distinct components. The prompt below is organized into ROLE, CONTEXT, TASK, CONSTRAINTS, EXAMPLES, and OUTPUT FORMAT sections.



In [None]:
prompt = """

###ROLE###
You are a highly skilled Text-to-SQL translator with expertise in SQL syntax, database schema interpretation, and natural language understanding. You generate syntactically correct and semantically accurate SQL queries based on user input and a given database schema.

###CONTEXT###
The user is working with a relational database for an e-commerce platform. The database includes three main tables: `customers`, `products`, and `orders`. The goal is to allow users to input natural language queries (in English), and have the model return equivalent SQL statements that accurately extract the requested data using the given schema.

Here is the full schema:

**Customers Table**
```sql
CREATE TABLE IF NOT EXISTS customers (
    customer_id INT PRIMARY KEY,
    first_name VARCHAR(50),
    last_name VARCHAR(50),
    email VARCHAR(50),
    phone_number VARCHAR(50),
    address VARCHAR(50),
    city VARCHAR(50),
    country VARCHAR(50),
    postal_code VARCHAR(50),
    loyalty_points INT
);
```

**Products Table**

```sql
CREATE TABLE IF NOT EXISTS products (
    product_id INT PRIMARY KEY,
    product_name TEXT,
    description TEXT,
    price DECIMAL(10,2),
    discount_percentage DECIMAL(5,2),
    category VARCHAR(50),
    brand TEXT,
    stock_quantity INT,
    color VARCHAR(50),
    size VARCHAR(20),
    weight DECIMAL(5,2),
    dimensions TEXT,
    release_date DATE,
    rating DECIMAL(3,1),
    reviews_count INT,
    seller_name TEXT,
    seller_rating DECIMAL(3,1),
    seller_reviews_count INT,
    shipping_method VARCHAR(20),
    shipping_cost DECIMAL(6,2)
);
```

**Orders Table**

```sql
CREATE TABLE IF NOT EXISTS orders (
    order_id INT PRIMARY KEY,
    customer_id INT,
    product_id INT,
    quantity INT,
    unit_price DECIMAL(10,2),
    total_price DECIMAL(10,2),
    order_date DATE,
    shipping_address VARCHAR(255),
    payment_method VARCHAR(20),
    status VARCHAR(20),
    FOREIGN KEY (customer_id) REFERENCES customers(customer_id),
    FOREIGN KEY (product_id) REFERENCES products(product_id)
);
```

###TASK###
Your task is to:

1. Read a natural language query about the e-commerce data.
2. Interpret the user's intent based on the schema provided.
3. Generate a valid SQL `SELECT` query that returns the expected result.
4. Ensure correct table joins, column selection, filtering, and grouping as necessary.
5. Handle aggregate functions (e.g., `COUNT`, `AVG`, `SUM`) where appropriate.
6. Disambiguate user terms based on schema details (e.g., "buyer" → `customers`, "product rating" → `products.rating`, etc.).

###CONSTRAINTS###

* Only return the SQL query, inside the code block described under OUTPUT FORMAT — no explanations or extra text.
* The user is using a SQLite database — respond with correct and valid SQLite syntax.
* Date columns (`order_date`, `release_date`) are stored as ISO-8601 text; filter them with SQLite date functions such as `strftime()` or with half-open ranges (`>= start AND < next_start`) instead of an inclusive `BETWEEN` ending on the last day.
* Use column aliases (`AS`) for computed expressions such as aggregates, and for columns whose original name is ambiguous.
* Do not create or modify tables.
* Do not assume the existence of tables or columns not provided in the schema.
* Avoid subqueries unless absolutely necessary for correctness or performance.
* Prefer readability: indent joins and clauses properly.

###EXAMPLES###
**Input:** "Show me the names and emails of customers from Canada who have more than 1000 loyalty points."
**Output:**

```sql
SELECT first_name, last_name, email
FROM customers
WHERE country = 'Canada' AND loyalty_points > 1000;
```

**Input:** "List the top 5 products with the highest ratings and their categories."
**Output:**

```sql
SELECT product_name, category, rating
FROM products
ORDER BY rating DESC
LIMIT 5;
```

**Input:** "How many orders were placed in August 2025?"
**Output:**

```sql
SELECT COUNT(*) AS total_orders
FROM orders
WHERE order_date >= '2025-08-01' AND order_date < '2025-09-01';
```

**Input:** "What is the average shipping cost for products sold by sellers with a rating above 4.5?"
**Output:**

```sql
SELECT AVG(shipping_cost) AS average_shipping_cost
FROM products
WHERE seller_rating > 4.5;
```

###OUTPUT FORMAT###
Return only the SQLite SQL query as a code block using triple backticks and the `sql` language tag, like this:

```sql
-- Your SQL query here
```
"""

In [None]:
import json
def get_sql_query_via_gemini(genai_client, prompt, user_query, model='gemini-2.5-flash'):
    """Ask Gemini to translate ``user_query`` into SQL and return only the SQL text.

    Token usage reported by the API is printed when it is available.

    Args:
        genai_client: A ``google.genai`` client, or any object exposing
            ``models.generate_content(model=..., contents=...)``.
        prompt: The instruction prompt (role, schema, constraints, examples).
        user_query: The user's natural-language question.
        model: Name of the Gemini model to call.

    Returns:
        The SQL inside the first fenced code block of the reply, with surrounding
        whitespace stripped. If the reply has no complete fenced block, it is the
        whole reply with any stray fence markers removed.

    Raises:
        ValueError: If the response contains no text (e.g. it was blocked or empty).
    """
    contents = (
        f"{prompt}\n\n"
        "Here's the user query in English you need to work on:\n"
        f"{user_query}\n"
    )
    response = genai_client.models.generate_content(model=model, contents=contents)

    # usage_metadata can be absent; only report token counts when it is present.
    usage = response.usage_metadata
    if usage is not None:
        print(f"Input Token Count: {usage.prompt_token_count}")
        print(f"Thoughts Token Count: {usage.thoughts_token_count}")
        print(f"Output Token Count: {usage.candidates_token_count}")
        print(f"Total Token Count: {usage.total_token_count}")

    text = response.text
    if not text:
        raise ValueError("Gemini returned no text for this query (the response may have been blocked or empty).")

    return _extract_sql(text)


# First fenced block, with an optional (case-insensitive) `sql` / `sqlite` tag.
_SQL_FENCE_RE = re.compile(r"```[ \t]*(?:sqlite|sql)?[ \t]*\n?(.*?)```", re.IGNORECASE | re.DOTALL)
# An opening or closing fence marker on its own (used when no complete block exists).
_STRAY_FENCE_RE = re.compile(r"```[ \t]*(?:sqlite|sql)?", re.IGNORECASE)


def _extract_sql(text):
    """Return the SQL from a model reply: the first fenced block if present, else the fence-stripped text."""
    match = _SQL_FENCE_RE.search(text)
    if match:
        return match.group(1).strip()
    return _STRAY_FENCE_RE.sub("", text).strip()


In [None]:
import sqlite3
import pandas as pd

def execute_query(query, db_name='ecommerce.db'):
    """Run one SQL statement against the SQLite database and return the rows as a DataFrame.

    The database is opened read-only because ``query`` is generated by an LLM. A
    stray INSERT/UPDATE/DELETE/DROP therefore fails instead of changing the data.
    A missing database file is reported as an error rather than being silently
    created empty.

    Args:
        query: SQL text containing a single statement.
        db_name: Path to the SQLite database file.

    Returns:
        A pandas DataFrame with one column per result column. Duplicate column
        names are kept, and an empty result still has its headers. Returns None
        if the statement failed or produced no result set.
    """
    conn = None
    try:
        # mode=ro opens the existing file read-only and errors if it does not exist.
        db_uri = f"{Path(db_name).resolve().as_uri()}?mode=ro"
        conn = sqlite3.connect(db_uri, uri=True)
        cursor = conn.cursor()

        print(f"\nExecuting query on '{db_name}':\n{query}")
        cursor.execute(query)

        if cursor.description is None:
            # The statement ran but returned no result set (e.g. empty input).
            print("Query returned no result set.")
            return None

        columns = [col_info[0] for col_info in cursor.description]
        # Build from row tuples rather than dicts so duplicate column names (e.g.
        # `SELECT c.customer_id, o.customer_id`) are both kept, and an empty
        # result still carries its column headers.
        results_df = pd.DataFrame(cursor.fetchall(), columns=columns)

        print("Query executed successfully.")
        return results_df

    except sqlite3.Error as e:
        print(f"Database error executing query: {e}")
        return None
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        return None
    finally:
        if conn is not None:
            conn.close()

In [None]:
def text2sql(genai_client, prompt, user_query):
    """Translate a natural-language question into SQL with Gemini and run it.

    Args:
        genai_client: Client passed through to ``get_sql_query_via_gemini``.
        prompt: Instruction prompt describing the schema and output rules.
        user_query: The user's question in natural language.

    Returns:
        Whatever ``execute_query`` returns for the generated SQL, run against its
        default database.
    """
    generated_sql = get_sql_query_via_gemini(genai_client, prompt, user_query)
    return execute_query(generated_sql)

In [None]:
# Needs a join: `country` lives in `customers`, while orders live in `orders`.
text2sql(genai_client, prompt, "Show me the order count by country")

Input Token Count: 1136
Thoughts Token Count: 82
Output Token Count: 59
Total Token Count: 1277

Executing query on 'ecommerce.db':

SELECT
  c.country,
  COUNT(o.order_id) AS order_count
FROM orders AS o
JOIN customers AS c
  ON o.customer_id = c.customer_id
GROUP BY
  c.country;

Query executed successfully.


Unnamed: 0,country,order_count
0,Afghanistan,14
1,Albania,17
2,Antigua and Barbuda,3
3,Argentina,58
4,Armenia,8
...,...,...
117,Uruguay,7
118,Venezuela,2
119,Vietnam,24
120,Yemen,2


In [None]:
# Open-ended question: the model must choose a popularity metric (the run below ranks products by total quantity sold).
text2sql(genai_client, prompt, "What are my most popular products")

Input Token Count: 1135
Thoughts Token Count: 687
Output Token Count: 86
Total Token Count: 1908

Executing query on 'ecommerce.db':

SELECT
  p.product_name,
  SUM(o.quantity) AS total_quantity_sold
FROM orders AS o
JOIN products AS p
  ON o.product_id = p.product_id
GROUP BY
  p.product_id,
  p.product_name
ORDER BY
  total_quantity_sold DESC
LIMIT 5;

Query executed successfully.


Unnamed: 0,product_name,total_quantity_sold
0,vel nisl duis ac nibh fusce lacus purus alique...,539
1,in consequat ut nulla sed accumsan felis ut at...,448
2,curae nulla dapibus dolor vel est donec odio j...,422
3,lorem ipsum dolor sit amet consectetuer adipis...,401
4,quam pharetra magna ac consequat metus sapien ...,390


In [None]:
# Needs a ranking over per-country sales totals to pick the middle-ranked country.
text2sql(genai_client, prompt, "which country ranks in middle by total sales??")

Input Token Count: 1138
Thoughts Token Count: 1582
Output Token Count: 161
Total Token Count: 2881

Executing query on 'ecommerce.db':

WITH CountrySales AS (
  SELECT
    c.country,
    SUM(o.total_price) AS total_sales
  FROM customers AS c
  JOIN orders AS o
    ON c.customer_id = o.customer_id
  GROUP BY
    c.country
),
RankedCountrySales AS (
  SELECT
    country,
    total_sales,
    ROW_NUMBER() OVER (ORDER BY total_sales DESC) AS sales_rank,
    COUNT(*) OVER () AS total_countries
  FROM CountrySales
)
SELECT
  country
FROM RankedCountrySales
WHERE
  sales_rank = ((total_countries - 1) / 2) + 1;

Query executed successfully.


Unnamed: 0,country
0,Belarus


In [None]:
# Per-group ranking: the second-best-selling product within each country.
text2sql(genai_client, prompt, "What is the 2nd highest sold product for each country?")

Input Token Count: 1142
Thoughts Token Count: 875
Output Token Count: 208
Total Token Count: 2225

Executing query on 'ecommerce.db':

WITH ProductSalesByCountry AS (
    SELECT
        c.country,
        p.product_name,
        SUM(o.quantity) AS total_sold_quantity
    FROM
        orders AS o
    JOIN
        customers AS c ON o.customer_id = c.customer_id
    JOIN
        products AS p ON o.product_id = p.product_id
    GROUP BY
        c.country,
        p.product_name
),
RankedProductSales AS (
    SELECT
        country,
        product_name,
        total_sold_quantity,
        ROW_NUMBER() OVER (PARTITION BY country ORDER BY total_sold_quantity DESC) AS rn
    FROM
        ProductSalesByCountry
)
SELECT
    country,
    product_name,
    total_sold_quantity
FROM
    RankedProductSales
WHERE
    rn = 2;

Query executed successfully.


Unnamed: 0,country,product_name,total_sold_quantity
0,Afghanistan,a nibh in quis justo maecenas rhoncus aliquam ...,96
1,Albania,vestibulum ante ipsum primis in faucibus orci ...,94
2,Antigua and Barbuda,magna vulputate luctus cum sociis natoque pena...,42
3,Argentina,consequat in consequat ut nulla sed accumsan f...,98
4,Armenia,proin eu mi nulla ac enim in tempor turpis nec...,43
...,...,...,...
113,Uruguay,tempus sit amet sem fusce consequat nulla nisl...,86
114,Venezuela,nulla pede ullamcorper augue a suscipit nulla ...,85
115,Vietnam,orci luctus et ultrices posuere cubilia curae ...,87
116,Yemen,morbi porttitor lorem id ligula suspendisse or...,41


In [None]:
# Filtered ranking: India's sales rank and order count among all countries.
# NOTE(review): the saved run below printed no result rows — check whether any customers have country = 'India'.
text2sql(genai_client, prompt, "rank and count of sales of India by total sales??")

Input Token Count: 1140
Thoughts Token Count: 2632
Output Token Count: 161
Total Token Count: 3933

Executing query on 'ecommerce.db':

SELECT
  ranked_sales.country,
  ranked_sales.total_orders,
  ranked_sales.total_sales,
  ranked_sales.sales_rank
FROM (
  SELECT
    c.country,
    COUNT(o.order_id) AS total_orders,
    SUM(o.total_price) AS total_sales,
    RANK() OVER (ORDER BY SUM(o.total_price) DESC) AS sales_rank
  FROM orders AS o
  JOIN customers AS c
    ON o.customer_id = c.customer_id
  GROUP BY
    c.country
) AS ranked_sales
WHERE
  ranked_sales.country = 'India';

Query executed successfully.


In [None]:
# Date-part aggregation: orders grouped by day of the month, busiest days first.
text2sql(genai_client, prompt, "Give me the order count by day of month and sort it by order count")

Input Token Count: 1144
Thoughts Token Count: 217
Output Token Count: 60
Total Token Count: 1421

Executing query on 'ecommerce.db':

SELECT
  STRFTIME('%d', order_date) AS day_of_month,
  COUNT(order_id) AS order_count
FROM orders
GROUP BY
  day_of_month
ORDER BY
  order_count DESC;

Query executed successfully.


Unnamed: 0,day_of_month,order_count
0,7,124
1,27,119
2,1,113
3,9,110
4,28,108
5,22,106
6,25,104
7,24,102
8,19,102
9,12,102


In [None]:
# Day-of-week breakdown; the run below maps SQLite strftime('%w') codes (0 = Sunday) to day names.
text2sql(genai_client, prompt, "On which day of the week do I get the most orders? Give me a detailed report.")

Input Token Count: 1148
Thoughts Token Count: 397
Output Token Count: 133
Total Token Count: 1678

Executing query on 'ecommerce.db':

SELECT
  CASE strftime('%w', order_date)
    WHEN '0' THEN 'Sunday'
    WHEN '1' THEN 'Monday'
    WHEN '2' THEN 'Tuesday'
    WHEN '3' THEN 'Wednesday'
    WHEN '4' THEN 'Thursday'
    WHEN '5' THEN 'Friday'
    WHEN '6' THEN 'Saturday'
  END AS day_of_week,
  COUNT(order_id) AS total_orders
FROM orders
GROUP BY
  day_of_week
ORDER BY
  total_orders DESC;

Query executed successfully.


Unnamed: 0,day_of_week,total_orders
0,Wednesday,451
1,Monday,445
2,Sunday,441
3,Friday,435
4,Thursday,429
5,Tuesday,404
6,Saturday,395
