<a href="https://colab.research.google.com/github/melvinbaiju27/14proMaX/blob/main/n2sql_google_gemini_notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Natural Language to SQL using Google's Gemini Pro | Python | Google AI Studio

[**Link to my YouTube Channel**](https://www.youtube.com/BhaveshBhatt8791?sub_confirmation=1)

Click on the link below to open a Colab version of the notebook. You will be able to create your own version.

<a href="https://colab.research.google.com/github/bhattbhavesh91/n2sql-google-gemini/blob/main/n2sql-google-gemini-notebook.ipynb" target="_blank"><img height="40" alt="Run your own notebook in Colab" src = "https://colab.research.google.com/assets/colab-badge.svg"></a>

# Installation

In [1]:
!pip install -q google-generativeai==0.3.1

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/146.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━[0m [32m92.2/146.6 kB[0m [31m3.7 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m146.6/146.6 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/598.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m593.9/598.7 kB[0m [31m28.3 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m598.7/598.7 kB[0m [31m16.3 MB/s[0m eta [36m0:00:00[0m
[?25h

# Imports

In [2]:
import google.generativeai as genai
from pathlib import Path
import sqlite3

# Version

In [3]:
# Sanity check: the reported version should match the one pinned in the install cell.
genai.__version__

'0.3.1'

# Secret Key

In [5]:
from google.colab import userdata

# Look the API key up by name ('GEMINI_KEY') through Colab's userdata helper
# instead of hardcoding it in the notebook.
genai.configure(api_key = userdata.get('GEMINI_KEY'))

# Configurations

In [6]:
# Generation parameters passed to the model when it is constructed.
generation_config = dict(
    temperature=0.4,
    top_p=1,
    top_k=32,
    max_output_tokens=4096,
)

# Every harm category gets the same blocking threshold.
safety_settings = [
    {"category": harm_category, "threshold": "BLOCK_MEDIUM_AND_ABOVE"}
    for harm_category in (
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )
]

# Model Instance

In [7]:
# Build the model once; later cells reuse this instance for every request.
model = genai.GenerativeModel(
    model_name="gemini-pro",
    generation_config=generation_config,
    safety_settings=safety_settings,
)

# SQL Query Executor

In [8]:
def read_sql_query(sql, db):
    """Run one SQL statement against a SQLite database and print each result row.

    Args:
        sql: SQL text to execute.
        db: Path of the SQLite database file to open.

    Returns:
        None. Each fetched row is printed to stdout as a tuple, one per line.

    Raises:
        sqlite3.Error: Propagated unchanged when the statement cannot be
            executed (for example, an unknown table or a syntax error).
    """
    conn = sqlite3.connect(db)
    try:
        cur = conn.cursor()
        cur.execute(sql)
        rows = cur.fetchall()
        for row in rows:
            print(row)
    finally:
        # Close in `finally` so a failing statement does not leak the connection.
        conn.close()

In [10]:
# Preview the first ten rows of the table the prompts below describe.
# NOTE(review): fashion_db.sqlite is not created anywhere in the visible cells;
# it presumably has to be present in the working directory - confirm.
read_sql_query(
    sql='SELECT * FROM fashion_products LIMIT 10;',
    db="fashion_db.sqlite",
)

(1, 'T-Shirt', 'Clothing', 19.99, 100)
(2, 'Jeans', 'Clothing', 49.99, 50)
(3, 'Sneakers', 'Footwear', 79.99, 30)
(4, 'Jacket', 'Clothing', 99.99, 20)
(5, 'Watch', 'Accessories', 199.99, 10)
(6, 'Backpack', 'Accessories', 39.99, 40)
(7, 'Sunglasses', 'Accessories', 29.99, 60)
(8, 'Dress', 'Clothing', 59.99, 25)
(9, 'Sandals', 'Footwear', 29.99, 35)
(10, 'Hat', 'Accessories', 14.99, 70)


# Define Prompt

In [15]:
# First version of the instruction prompt. It lists the `fashion_products`
# columns, gives five worked question/SQL examples (each SQL example wrapped in
# ``` fences), and ends with the output rules. The user's question is not part
# of this string; it is passed to the model as a separate part.
prompt_parts_1 = [
    """You are an expert in converting English questions to SQL code! The SQL database has the name `fashion_products` and has the following columns:
    - `id` (unique identifier for each product),
    - `product_name` (name of the product),
    - `category` (category of the product, e.g., Clothing, Footwear, Accessories),
    - `price` (price of the product),
    - `stock_quantity` (quantity of the product in stock).

    For example:
    Example 1 - How many products are in the Clothing category?
    ```
    SELECT COUNT(*) FROM fashion_products WHERE category = 'Clothing';
    ```

    Example 2 - What is the most expensive product?
    ```
    SELECT product_name FROM fashion_products WHERE price = (SELECT MAX(price) FROM fashion_products);
    ```

    Example 3 - List all products with a stock quantity less than 30.
    ```
    SELECT product_name FROM fashion_products WHERE stock_quantity < 30;
    ```

    Example 4 - What is the average price of products in the Footwear category?
    ```
    SELECT AVG(price) FROM fashion_products WHERE category = 'Footwear';
    ```

    Example 5 - How many products are priced above $50?
    ```
    SELECT COUNT(*) FROM fashion_products WHERE price > 50;
    ```

    **Rules:**
    1. Always use the correct table name (`fashion_products`) and column names.
    2. Only generate `SELECT` queries. Do not generate `INSERT`, `UPDATE`, `DELETE`, or `DROP` queries.
    3. Do not include ``` or \\n in the output.
    4. Ensure the SQL query is valid and optimized.

    Now, generate the SQL query for the following question:
    """
]

In [16]:
# Natural-language question to be translated into SQL.
question = "Tell me the id of the most expensive T-shirt?"

In [17]:
# Send the instruction prompt and the question to the model as two separate parts.
prompt_parts = [prompt_parts_1[0], question]
response = model.generate_content(prompt_parts)
# Show the raw reply text as the cell output.
response.text

"```\nSELECT id FROM fashion_products WHERE product_name LIKE '%T-Shirt%' AND price = (SELECT MAX(price) FROM fashion_products WHERE product_name LIKE '%T-Shirt%');\n```"

In [19]:
# Run the SQL from the model's reply above, pasted in by hand with the ``` fence
# lines left out.
read_sql_query("""SELECT id FROM fashion_products WHERE product_name LIKE '%T-Shirt%' AND price = (SELECT MAX(price) FROM fashion_products WHERE product_name LIKE '%T-Shirt%');
""",
               "fashion_db.sqlite")

(1,)


# Combine It All into Functions

In [30]:
import re
import sqlite3
from google.colab import userdata
import google.generativeai as genai

In [31]:
# Function to clean SQL query
def clean_sql_query(sql_query):
    """Strip Markdown code fences and line breaks from a generated SQL string.

    Args:
        sql_query: Raw text, possibly wrapped in ``` fences (with or without a
            `sql`/`sqlite` language tag in any letter case) and possibly
            spanning several lines.

    Returns:
        The query on a single line with leading/trailing whitespace removed.
    """
    # Drop the fence markers together with an optional language tag.
    without_fences = re.sub(r'```(?:sqlite|sql)?', '', sql_query, flags=re.IGNORECASE)
    # Replace each line break (and the indentation around it) with ONE space.
    # Deleting the break outright would fuse neighbouring tokens, e.g.
    # "SELECT id\nFROM t" would become "SELECT idFROM t".
    cleaned_query = re.sub(r'\s*[\r\n]+\s*', ' ', without_fences).strip()
    return cleaned_query

# Function to execute SQL query
def read_sql_query(sql, db):
    """Clean and execute a (model-generated) SQL query in read-only mode.

    Args:
        sql: SQL text; it is passed through `clean_sql_query` before execution.
        db: Path of the SQLite database file to open.

    Returns:
        A list of row tuples on success, or a string starting with
        "Error executing SQL query: " when SQLite rejects the statement.
    """
    # Clean the SQL query
    sql = clean_sql_query(sql)

    # Connect to the SQLite database
    conn = sqlite3.connect(db)

    try:
        # The SQL text is generated by a model, so block writes on this
        # connection: with query_only enabled, INSERT/UPDATE/DELETE/DROP fail
        # instead of changing the database.
        conn.execute("PRAGMA query_only = ON")

        # Execute the query
        cur = conn.cursor()
        cur.execute(sql)
        rows = cur.fetchall()
    except (sqlite3.Error, sqlite3.Warning) as e:
        # sqlite3.Warning is not a subclass of sqlite3.Error; some Python
        # versions raise it when the text contains more than one statement.
        return f"Error executing SQL query: {e}"
    finally:
        # Close the connection
        conn.close()

    return rows

In [32]:
# Function to generate Gemini response
def generate_gemini_response(question, input_prompt, db="fashion_db.sqlite"):
    """Ask the model for SQL answering `question`, run it, and return the result.

    Args:
        question: Natural-language question to translate into SQL.
        input_prompt: Instruction prompt sent ahead of the question.
        db: Path of the SQLite database to query. Defaults to the
            "fashion_db.sqlite" file used throughout this notebook.

    Returns:
        Whatever `read_sql_query` returns for the generated SQL, or a string
        starting with "Error generating SQL query: " when the reply has no
        usable text.
    """
    # Combine the prompt and question
    prompt_parts = [input_prompt, question]

    # Generate the SQL query using Gemini (uses the module-level `model`)
    response = model.generate_content(prompt_parts)

    # Extract the generated SQL query
    try:
        generated_sql = response.text
    except ValueError as e:
        # The `.text` accessor raises ValueError when the reply carries no
        # simple text (e.g. the prompt was blocked). Report it in the same
        # string style that read_sql_query uses for its errors.
        return f"Error generating SQL query: {e}"

    # Execute the SQL query and return the result
    return read_sql_query(generated_sql, db)

In [33]:
# Second version of the instruction prompt. Same column list and rules as the
# earlier one, but the five example queries are written as plain SQL lines with
# no ``` fences around them. Rebinding `prompt_parts_1` replaces the earlier value.
prompt_parts_1 = [
    """You are an expert in converting English questions to SQL code! The SQL database has the name `fashion_products` and has the following columns:
    - `id` (unique identifier for each product),
    - `product_name` (name of the product),
    - `category` (category of the product, e.g., Clothing, Footwear, Accessories),
    - `price` (price of the product),
    - `stock_quantity` (quantity of the product in stock).

    For example:
    Example 1 - How many products are in the Clothing category?
    SELECT COUNT(*) FROM fashion_products WHERE category = 'Clothing';

    Example 2 - What is the most expensive product?
    SELECT product_name FROM fashion_products WHERE price = (SELECT MAX(price) FROM fashion_products);

    Example 3 - List all products with a stock quantity less than 30.
    SELECT product_name FROM fashion_products WHERE stock_quantity < 30;

    Example 4 - What is the average price of products in the Footwear category?
    SELECT AVG(price) FROM fashion_products WHERE category = 'Footwear';

    Example 5 - How many products are priced above $50?
    SELECT COUNT(*) FROM fashion_products WHERE price > 50;

    **Rules:**
    1. Always use the correct table name (`fashion_products`) and column names.
    2. Only generate `SELECT` queries. Do not generate `INSERT`, `UPDATE`, `DELETE`, or `DROP` queries.
    3. Do not include ``` or \\n in the output.
    4. Ensure the SQL query is valid and optimized.

    Now, generate the SQL query for the following question:
    """
]

In [34]:
# End-to-end demo: question -> generated SQL -> rows from the database.
question = "What is the most expensive product?"
output = generate_gemini_response(
    question,
    input_prompt=prompt_parts_1[0],
)
print("Generated SQL Query Result:", output)

Generated SQL Query Result: [('Watch',)]
