<a href="https://colab.research.google.com/github/maxsop/R-junk/blob/master/n2sql-google-gemini-notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Natural Language to SQL using Google's Gemini Pro

<a href="https://colab.research.google.com/github/bhattbhavesh91/n2sql-google-gemini/blob/main/n2sql-google-gemini-notebook.ipynb" target="_blank"><img height="40" alt="Run your own notebook in Colab" src = "https://colab.research.google.com/assets/colab-badge.svg"></a>

# Installation

In [1]:
!pip install -q google-generativeai==0.3.1

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/146.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━[0m [32m81.9/146.6 kB[0m [31m2.2 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m146.6/146.6 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/598.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m593.9/598.7 kB[0m [31m22.1 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m598.7/598.7 kB[0m [31m12.2 MB/s[0m eta [36m0:00:00[0m
[?25h

# Imports

In [19]:
import google.generativeai as genai
from pathlib import Path
import sqlite3
import pandas as pd

# Version

In [15]:
genai.__version__

'0.3.1'

# Secret Key

In [5]:
from google.colab import userdata

genai.configure(api_key = userdata.get('GEMINI_KEY'))

# Configurations

In [14]:
# Set up the model
generation_config = {
  "temperature": 0.4,
  "top_p": 1,
  "top_k": 32,
  "max_output_tokens": 4096,
}

safety_settings = [
  {
    "category": "HARM_CATEGORY_HARASSMENT",
    "threshold": "BLOCK_MEDIUM_AND_ABOVE"
  },
  {
    "category": "HARM_CATEGORY_HATE_SPEECH",
    "threshold": "BLOCK_MEDIUM_AND_ABOVE"
  },
  {
    "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
    "threshold": "BLOCK_MEDIUM_AND_ABOVE"
  },
  {
    "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
    "threshold": "BLOCK_MEDIUM_AND_ABOVE"
  }
]

# Model Instance

In [7]:
model = genai.GenerativeModel(model_name = "gemini-pro",
                              generation_config = generation_config,
                              safety_settings = safety_settings)

# Convert pandas to sql db

In [45]:
df = pd.read_csv("archive.zip")
df.head()

Unnamed: 0,index,url,name,sku,selling_price,original_price,currency,availability,color,category,...,source_website,breadcrumbs,description,brand,images,country,language,average_rating,reviews_count,crawled_at
0,0,https://www.adidas.com/us/beach-shorts/FJ5089....,Beach Shorts,FJ5089,40,,USD,InStock,Black,Clothing,...,https://www.adidas.com,Women/Clothing,Splashing in the surf. Making memories with yo...,adidas,"https://assets.adidas.com/images/w_600,f_auto,...",USA,en,4.5,35,2021-10-23 17:50:17.331255
1,1,https://www.adidas.com/us/five-ten-kestrel-lac...,Five Ten Kestrel Lace Mountain Bike Shoes,BC0770,150,,USD,InStock,Grey,Shoes,...,https://www.adidas.com,Women/Shoes,Lace up and get after it. The Five Ten Kestrel...,adidas,"https://assets.adidas.com/images/w_600,f_auto,...",USA,en,4.8,4,2021-10-23 17:50:17.423830
2,2,https://www.adidas.com/us/mexico-away-jersey/G...,Mexico Away Jersey,GC7946,70,,USD,InStock,White,Clothing,...,https://www.adidas.com,Kids/Clothing,"Clean and crisp, this adidas Mexico Away Jerse...",adidas,"https://assets.adidas.com/images/w_600,f_auto,...",USA,en,4.9,42,2021-10-23 17:50:17.530834
3,3,https://www.adidas.com/us/five-ten-hiangle-pro...,Five Ten Hiangle Pro Competition Climbing Shoes,FV4744,160,,USD,InStock,Black,Shoes,...,https://www.adidas.com,Five Ten/Shoes,The Hiangle Pro takes on the classic shape of ...,adidas,"https://assets.adidas.com/images/w_600,f_auto,...",USA,en,3.7,7,2021-10-23 17:50:17.615054
4,4,https://www.adidas.com/us/mesh-broken-stripe-p...,Mesh Broken-Stripe Polo Shirt,GM0239,65,,USD,InStock,Blue,Clothing,...,https://www.adidas.com,Men/Clothing,Step up to the tee relaxed. This adidas golf p...,adidas,"https://assets.adidas.com/images/w_600,f_auto,...",USA,en,4.7,11,2021-10-23 17:50:17.702680


In [27]:
df.shape

(845, 21)

In [31]:
df.columns

Index(['index', 'url', 'name', 'sku', 'selling_price', 'original_price',
       'currency', 'availability', 'color', 'category', 'source',
       'source_website', 'breadcrumbs', 'description', 'brand', 'images',
       'country', 'language', 'average_rating', 'reviews_count', 'crawled_at'],
      dtype='object')

In [46]:
df_usd = df[(df['currency'] == 'USD') & (df['selling_price'] > 0) & (df['language'] == 'en')]
df_usd.shape
df.shape

(845, 21)

In [47]:
df = df.drop(columns=['index','url', 'sku', 'crawled_at','original_price','images'], axis=1)

In [35]:
df.shape

(845, 17)

In [48]:
# Rename columns
df = df.rename(columns={
    'name': 'product_name',
    'selling_price': 'price',
    'brand': 'brand_name',
    'description': 'product_desc',
    'category': 'category_name'
})
df.columns


Index(['product_name', 'price', 'currency', 'availability', 'color',
       'category_name', 'source', 'source_website', 'breadcrumbs',
       'product_desc', 'brand_name', 'country', 'language', 'average_rating',
       'reviews_count'],
      dtype='object')

In [51]:
conn = sqlite3.connect('fashion_db.sqlite')
c = conn.cursor()

c.execute('CREATE TABLE IF NOT EXISTS fashion_products (product_name text, brand_name text, category text, price int, rating float, color text, product_desc text, country text, reviews_count int, source text)')
conn.commit()

df.to_sql('fashion_products', conn, if_exists='replace', index = False)


845

In [53]:
c.execute('''
SELECT product_name FROM fashion_products LIMIT 10
          ''')

for row in c.fetchall():
    print (row)
c.close()

('Beach Shorts',)
('Five Ten Kestrel Lace Mountain Bike Shoes',)
('Mexico Away Jersey',)
('Five Ten Hiangle Pro Competition Climbing Shoes',)
('Mesh Broken-Stripe Polo Shirt',)
('EQT Spikeless Golf Shoes',)
('Adicross Hybrid Shorts',)
('Tiro 21 Windbreaker',)
('Classic 3-Stripes Swimsuit',)
('Tiro 21 Windbreaker',)


# SQL Query Executor

In [8]:
def read_sql_query(sql, db):
    conn = sqlite3.connect(db)
    cur = conn.cursor()
    cur.execute(sql)
    rows = cur.fetchall()
    for row in rows:
        print(row)
    conn.close()

In [9]:
def get_row_count(db_name):
    """Generates SQL query to find total number of rows in a database table.

    Args:
        db_name: Name of the SQLite database file.

    Returns:
        SQL query as a string.
    """

    return "SELECT COUNT(*) FROM fashion_products;"



In [16]:
sql_query = get_row_count('/content/fashion_db.sqlite')
sql_query
# To execute the query use the read_sql_query function from previous code.
read_sql_query(sql_query, 'fashion_db.sqlite')

OperationalError: no such table: fashion_products

In [11]:
read_sql_query('SELECT * FROM fashion_products LIMIT 10;',
               "/content/fashion_db.sqlite")

OperationalError: no such table: fashion_products

# Define Prompt

In [None]:
prompt_parts_1 = [
  "You are an expert in converting English questions to SQL code! The SQL database has the name fashion_products and has the following columns - user_id, product_id, product_name, brand, category, price, color, and size.\n\nFor example,\nExample 1 - How many entries of Adidas are present?, the SQL command will be something like this\n``` SELECT COUNT(*) FROM fashion_products WHERE brand = 'Adidas';\n```\n\nExample 2 - How many XL products of Nike are there that have a rating of more than 4?\n```\nSELECT COUNT(*) FROM fashion_products WHERE brand = 'Nike' AND size = 'XL' AND \"Rating\" > 4;\n```\n\nExample 3 - \n```\nSELECT product_name FROM fashion_products WHERE price = (SELECT MAX(price) FROM fashion_products);\n```\n\nDont include ``` and \\n in the output",
]

In [None]:
question = "Tell me the id of the most expensive T-shirt?"

In [None]:
prompt_parts = [prompt_parts_1[0], question]
response = model.generate_content(prompt_parts)
response.text

"SELECT product_id FROM fashion_products WHERE product_name = 'T-shirt' AND price = (SELECT MAX(price) FROM fashion_products WHERE product_name = 'T-shirt');"

In [None]:
read_sql_query("""SELECT product_id FROM fashion_products WHERE product_name = 'T-shirt' AND price = (SELECT MAX(price) FROM fashion_products WHERE product_name = 'T-shirt');
""",
               "fashion_db.sqlite")

(938,)


# Combine it into Function

In [None]:
def generate_gemini_response(question, input_prompt):
    prompt_parts = [input_prompt, question]
    response = model.generate_content(prompt_parts)
    output = read_sql_query(response.text, "fashion_db.sqlite")
    return output

In [None]:
generate_gemini_response("How many products of Nike are there?",
                         prompt_parts_1[0])

(214,)
