In [1]:
!pip install torch transformers bitsandbytes accelerate sqlparse

Collecting transformers
  Using cached transformers-4.37.2-py3-none-any.whl.metadata (129 kB)
Collecting bitsandbytes
  Downloading bitsandbytes-0.42.0-py3-none-any.whl.metadata (9.9 kB)
Collecting accelerate
  Using cached accelerate-0.27.2-py3-none-any.whl.metadata (18 kB)
Collecting sqlparse
  Downloading sqlparse-0.4.4-py3-none-any.whl.metadata (4.0 kB)
Collecting huggingface-hub<1.0,>=0.19.3 (from transformers)
  Using cached huggingface_hub-0.20.3-py3-none-any.whl.metadata (12 kB)
Collecting regex!=2019.12.17 (from transformers)
  Using cached regex-2023.12.25-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (40 kB)
Collecting tokenizers<0.19,>=0.14 (from transformers)
  Using cached tokenizers-0.15.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Collecting safetensors>=0.4.1 (from transformers)
  Using cached safetensors-0.4.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.8 kB)
Using cached transformers-4

In [2]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

In [5]:
# Hub identifier of the text-to-SQL checkpoint; the same name is reused when
# the model weights are loaded, so tokenizer and model always match.
model_name = "defog/sqlcoder-7b-2"
tokenizer = AutoTokenizer.from_pretrained(model_name)

tokenizer_config.json:   0%|          | 0.00/1.84k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/515 [00:00<?, ?B/s]

In [6]:
# Load the checkpoint in half precision (float16); device_map="auto" leaves
# the placement of the weights to the library instead of naming a device here.
# NOTE(review): trust_remote_code=True permits Python code shipped in the model
# repository to run locally — confirm this checkpoint actually requires it.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    torch_dtype=torch.float16,
    device_map="auto",
    use_cache=True,
)

config.json:   0%|          | 0.00/691 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.94G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/3.59G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

In [7]:
# Prompt template for the sample sales schema (products, customers,
# salespeople, sales, product_suppliers).
# The {question} placeholder appears twice and is filled in with str.format,
# so any literal curly brace added to this text would have to be doubled.
prompt = """### Task
Generate a SQL query to answer [QUESTION]{question}[/QUESTION]

### Instructions
- If you cannot answer the question with the available database schema, return 'I do not know'
- Remember that revenue is price multiplied by quantity
- Remember that cost is supply_price multiplied by quantity

### Database Schema
This query will run on a database whose schema is represented in this string:
CREATE TABLE products (
  product_id INTEGER PRIMARY KEY, -- Unique ID for each product
  name VARCHAR(50), -- Name of the product
  price DECIMAL(10,2), -- Price of each unit of the product
  quantity INTEGER  -- Current quantity in stock
);

CREATE TABLE customers (
   customer_id INTEGER PRIMARY KEY, -- Unique ID for each customer
   name VARCHAR(50), -- Name of the customer
   address VARCHAR(100) -- Mailing address of the customer
);

CREATE TABLE salespeople (
  salesperson_id INTEGER PRIMARY KEY, -- Unique ID for each salesperson
  name VARCHAR(50), -- Name of the salesperson
  region VARCHAR(50) -- Geographic sales region
);

CREATE TABLE sales (
  sale_id INTEGER PRIMARY KEY, -- Unique ID for each sale
  product_id INTEGER, -- ID of product sold
  customer_id INTEGER,  -- ID of customer who made purchase
  salesperson_id INTEGER, -- ID of salesperson who made the sale
  sale_date DATE, -- Date the sale occurred
  quantity INTEGER -- Quantity of product sold
);

CREATE TABLE product_suppliers (
  supplier_id INTEGER PRIMARY KEY, -- Unique ID for each supplier
  product_id INTEGER, -- Product ID supplied
  supply_price DECIMAL(10,2) -- Unit price charged by supplier
);

-- sales.product_id can be joined with products.product_id
-- sales.customer_id can be joined with customers.customer_id
-- sales.salesperson_id can be joined with salespeople.salesperson_id
-- product_suppliers.product_id can be joined with products.product_id

### Answer
Given the database schema, here is the SQL query that answers [QUESTION]{question}[/QUESTION]
[SQL]
"""

In [8]:
import sqlparse

def generate_query(question, prompt_template=None):
    """Generate a formatted SQL query that answers ``question``.

    The prompt template is filled with the question, the module-level
    ``model`` decodes greedily (no sampling, a single beam), and the text
    after the last ``[SQL]`` marker of the decoded sequence is re-indented
    with ``sqlparse``.

    Args:
        question: Natural-language question substituted for every
            ``{question}`` placeholder in the template.
        prompt_template: Optional template string containing ``{question}``.
            When omitted, the module-level ``prompt`` is used as it is
            defined at call time.

    Returns:
        str: The generated SQL, re-indented by ``sqlparse.format``.
    """
    template = prompt if prompt_template is None else prompt_template
    updated_prompt = template.format(question=question)

    # The model is loaded with device_map="auto", so follow the model's own
    # device instead of assuming the inputs belong on "cuda".
    inputs = tokenizer(updated_prompt, return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=400,
        do_sample=False,
        num_beams=1,
    )
    outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    # Empty the CUDA cache so that more results can be generated without
    # running out of memory. This is particularly important on Colab; memory
    # management is much more straightforward when running on an inference
    # service. Guarded so the function also works on a CPU-only runtime.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    # The decoded text still contains the prompt; keep only what follows the
    # final "[SQL]" marker.
    return sqlparse.format(outputs[0].split("[SQL]")[-1], reindent=True)

In [9]:
# Ask a revenue question; the formatted SQL is kept in generated_sql so it can
# be displayed separately.
question = "What was our revenue by product in the New York region last month?"
generated_sql = generate_query(question)

In [10]:
# print() so the newlines in the formatted SQL are rendered as line breaks.
print(generated_sql)


SELECT p.product_id,
       SUM(s.quantity * p.price) AS revenue
FROM sales s
JOIN salespeople sp ON s.salesperson_id = sp.salesperson_id
JOIN products p ON s.product_id = p.product_id
WHERE sp.region = 'New York'
  AND s.sale_date >= (CURRENT_DATE - INTERVAL '1 month')
GROUP BY p.product_id
ORDER BY revenue DESC NULLS LAST;


In [18]:
# Prompt template for the e-commerce schema (products, customers, order).
# Assigning to `prompt` again replaces the earlier template for every later
# generate_query call that does not pass its own template.
# NOTE(review): this cell's recorded execution count is higher than that of
# some cells below it, so their recorded outputs may come from an earlier
# revision of this prompt — re-run them after editing.
# NOTE(review): "order" is a reserved SQL keyword in most dialects, so queries
# generated against this table name may need quoting — confirm the real name.
prompt = """### Task
Generate a SQL query to answer [QUESTION]{question}[/QUESTION]

### Instructions
- If you cannot answer the question with the available database schema, return 'I do not know'
- Remember that amount is price multiplied by quantity
- Remember to handle NULL in attribute columns

### Database Schema
This query will run on a database whose schema is represented in this string:
CREATE TABLE products (
  product_id STRING PRIMARY KEY, -- Unique ID for each product
  product_category STRING, -- Name of the product category
  price FLOAT -- Price of each unit of the product
);

CREATE TABLE customers (
   user_id STRING PRIMARY KEY, -- Unique ID for each customer
   customer_city STRING -- Name of the customer city
);

CREATE TABLE order (
  order_id STRING PRIMARY KEY, -- Unique ID for each order
  timestamp TIMESTAMP, -- Timestamp of order
  user_id STRING,  -- ID of customer who made purchase
  product_id STRING, -- ID of product sold
  quantity FLOAT, -- Quantity of product sold
  price FLOAT, -- Price of product sold
  review_score FLOAT -- Rating for order
);

-- order.product_id can be joined with products.product_id
-- order.user_id can be joined with customers.user_id

### Answer
Given the database schema, here is the SQL query that answers [QUESTION]{question}[/QUESTION]
[SQL]
"""

In [12]:
# Same revenue question, now run against the most recently assigned prompt.
# NOTE(review): "in the New York last month" reads like a leftover from
# "in the New York region" — confirm the intended wording.
question = "What was our revenue by product in the New York last month?"
generated_sql = generate_query(question)

In [13]:
# print() so the newlines in the formatted SQL are rendered as line breaks.
print(generated_sql)


SELECT p.product_id,
       SUM(o.price * o.quantity) AS total_revenue
FROM
order o
JOIN customers c ON o.user_id = c.user_id
JOIN products p ON o.product_id = p.product_id
WHERE c.customer_city ilike '%New York%'
  AND o.timestamp >= (CURRENT_DATE - INTERVAL '1 month')
GROUP BY p.product_id
ORDER BY total_revenue DESC NULLS LAST;


In [14]:
# Simple aggregation question that only needs the products table.
question = "How many products are there in each category?"
generated_sql = generate_query(question)
print(generated_sql)


SELECT p.product_category,
       COUNT(p.product_id) AS product_count
FROM products p
GROUP BY p.product_category
ORDER BY product_count DESC NULLS LAST;


In [21]:
# Anti-join style question: categories with no purchases.
# NOTE(review): the recorded query keeps products that have no matching order
# row, so it presumably lists categories containing at least one unsold
# product rather than categories that were never bought — confirm the intent.
question = "What categories are not bought by a customer?"
generated_sql = generate_query(question)
print(generated_sql)


SELECT p.product_category
FROM products p
LEFT JOIN
order o ON p.product_id = o.product_id
WHERE o.product_id IS NULL
GROUP BY p.product_category
ORDER BY p.product_category NULLS LAST;


In [22]:
# Per-customer variant of the previous question.
# NOTE(review): the wording "not bought group by customer" is ambiguous, and
# the recorded query keeps only customers without a matching order row, so
# product_category is presumably always NULL in its result — rephrase and
# verify.
question = "What categories are not bought group by customer?"
generated_sql = generate_query(question)
print(generated_sql)


SELECT c.user_id,
       p.product_category
FROM customers c
LEFT JOIN
order o ON c.user_id = o.user_id
LEFT JOIN products p ON o.product_id = p.product_id
WHERE o.product_id IS NULL
GROUP BY c.user_id,
         p.product_category;
