<a href="https://colab.research.google.com/github/solvedbrunus/lab-sql-generation-with-transformer-api/blob/main/lab-sql-generation-with-transformer-api.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# SQL Generation with Transformer API

In [1]:
!pip install torch transformers bitsandbytes accelerate sqlparse



In [2]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

In [3]:
torch.cuda.is_available()

True

In [4]:
available_memory = torch.cuda.get_device_properties(0).total_memory

In [5]:
print(available_memory)

15835660288
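
The raw byte count is easier to compare against the VRAM thresholds discussed in this notebook once it is converted to gigabytes. A small sketch of that arithmetic follows; `bytes_to_gb` and `bytes_to_gib` are hypothetical helper names, not part of the original notebook.

```python
def bytes_to_gb(n_bytes: int) -> float:
    """Convert bytes to decimal gigabytes (1 GB = 10**9 bytes)."""
    return n_bytes / 1e9


def bytes_to_gib(n_bytes: int) -> float:
    """Convert bytes to binary gibibytes (1 GiB = 2**30 bytes)."""
    return n_bytes / 2**30


# Value printed by the cell above
available_memory = 15835660288
print(f"{bytes_to_gb(available_memory):.2f} GB")    # 15.84 GB
print(f"{bytes_to_gib(available_memory):.2f} GiB")  # 14.75 GiB
```

In decimal units this GPU reports slightly more than 15 GB, which is why a comparison such as `available_memory > 15e9` evaluates to true here.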


## Download the Model
Use a GPU runtime on Colab (or any machine of your own with more than 30GB of VRAM) to load this model in float16. If that is unavailable, use a GPU with at least 8GB of VRAM to load it in 8-bit, or at least 5GB of VRAM to load it in 4-bit.

This step can take around 5 minutes the first time, so please be patient :)
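
The VRAM guidance above can be sketched as a small selection helper. This is only an illustration: `choose_precision` is a hypothetical name, the returned strings are plain labels rather than library arguments, the 8 GB and 5 GB minimums come from the text above, and the 15 GB float16 cutoff is borrowed from the `available_memory > 15e9` check in the loading cell rather than from the 30 GB figure.

```python
def choose_precision(total_vram_bytes: float) -> str:
    """Pick a loading mode label from the total GPU memory in bytes.

    Thresholds are assumptions taken from this notebook, not library defaults.
    """
    if total_vram_bytes > 15e9:
        return "float16"
    if total_vram_bytes >= 8e9:
        return "8bit"
    if total_vram_bytes >= 5e9:
        return "4bit"
    raise ValueError("Less than 5 GB of VRAM: too little for any of the listed modes")


print(choose_precision(15835660288))  # float16
```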

In [6]:
model_name = "defog/sqlcoder-7b-2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
if available_memory > 15e9:
    # if you have at least 15GB of GPU memory, load the model in float16
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        device_map="auto",
        use_cache=True,
    )
else:
    # else, load in 8 bits – this is a bit slower
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        # torch_dtype=torch.float16,
        load_in_8bit=True,
        device_map="auto",
        use_cache=True,
    )


## Set the Question & Prompt and Tokenize
Feel free to change the schema in the prompt below to your own schema.

In [7]:
prompt = """### Task
Generate a SQL query to answer [QUESTION]{question}[/QUESTION]

### Instructions
- If you cannot answer the question with the available database schema, return 'I do not know'
- Remember that revenue is price multiplied by quantity
- Remember that cost is supply_price multiplied by quantity

### Database Schema
This query will run on a database whose schema is represented in this string:
CREATE TABLE products (
  product_id INTEGER PRIMARY KEY, -- Unique ID for each product
  name VARCHAR(50), -- Name of the product
  price DECIMAL(10,2), -- Price of each unit of the product
  quantity INTEGER  -- Current quantity in stock
);

CREATE TABLE customers (
   customer_id INTEGER PRIMARY KEY, -- Unique ID for each customer
   name VARCHAR(50), -- Name of the customer
   address VARCHAR(100) -- Mailing address of the customer
);

CREATE TABLE salespeople (
  salesperson_id INTEGER PRIMARY KEY, -- Unique ID for each salesperson
  name VARCHAR(50), -- Name of the salesperson
  region VARCHAR(50) -- Geographic sales region
);

CREATE TABLE sales (
  sale_id INTEGER PRIMARY KEY, -- Unique ID for each sale
  product_id INTEGER, -- ID of product sold
  customer_id INTEGER,  -- ID of customer who made purchase
  salesperson_id INTEGER, -- ID of salesperson who made the sale
  sale_date DATE, -- Date the sale occurred
  quantity INTEGER -- Quantity of product sold
);

CREATE TABLE product_suppliers (
  supplier_id INTEGER PRIMARY KEY, -- Unique ID for each supplier
  product_id INTEGER, -- Product ID supplied
  supply_price DECIMAL(10,2) -- Unit price charged by supplier
);

-- sales.product_id can be joined with products.product_id
-- sales.customer_id can be joined with customers.customer_id
-- sales.salesperson_id can be joined with salespeople.salesperson_id
-- product_suppliers.product_id can be joined with products.product_id

### Answer
Given the database schema, here is the SQL query that answers [QUESTION]{question}[/QUESTION]
[SQL]
"""

## Generate the SQL
This can be excruciatingly slow on a T4 in Colab, and can take 10-20 seconds per query. On faster GPUs, this will take ~1-2 seconds.

Ideally, you should use `num_beams=4` for best results, but because of memory constraints we will stick to just 1 for now.
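
To make it easy to switch between greedy decoding and beam search later, the decoding options can be gathered in one place. The sketch below is a hedged illustration: `build_generation_kwargs` is a hypothetical helper, and it only assembles a dictionary of keyword arguments using the same parameter names that the generation cell passes to `model.generate`.

```python
def build_generation_kwargs(num_beams: int = 1, max_new_tokens: int = 400) -> dict:
    """Collect the decoding keyword arguments for `model.generate` in one place."""
    if num_beams < 1:
        raise ValueError("num_beams must be at least 1")
    return {
        "num_return_sequences": 1,
        "max_new_tokens": max_new_tokens,
        "do_sample": False,      # deterministic decoding, no sampling
        "num_beams": num_beams,  # 1 = greedy search, >1 = beam search
    }


greedy = build_generation_kwargs()
beam = build_generation_kwargs(num_beams=4)
print(greedy["num_beams"], beam["num_beams"])  # 1 4
```

The dictionary could then be unpacked into the call, for example `model.generate(**inputs, **build_generation_kwargs(num_beams=4))`, on a GPU with enough memory to hold several candidate sequences at once.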

In [8]:
import sqlparse

def generate_query(question):
    updated_prompt = prompt.format(question=question)
    inputs = tokenizer(updated_prompt, return_tensors="pt").to("cuda")
    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=400,
        do_sample=False,
        num_beams=1,
    )
    outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    # Empty the cache so that you can generate more results without running out of memory.
    # This is particularly important on Colab; memory management is much more
    # straightforward when running on an inference service.
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    return sqlparse.format(outputs[0].split("[SQL]")[-1], reindent=True)
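
The decoded output contains the echoed prompt followed by the completion, which is why `generate_query` keeps only the text after the last `[SQL]` marker. The sketch below isolates that step without a model or `sqlparse`. It is a minimal illustration: `extract_sql` is a hypothetical helper, the decoded string is invented, and it deliberately raises an error when the marker is missing, whereas `split("[SQL]")[-1]` would silently return the whole text.

```python
def extract_sql(decoded: str, marker: str = "[SQL]") -> str:
    """Return the text that follows the last occurrence of `marker`."""
    _, found, tail = decoded.rpartition(marker)
    if not found:
        # Unlike `split(marker)[-1]`, fail loudly when the marker is missing
        raise ValueError(f"marker {marker!r} not found in decoded output")
    return tail.strip()


# Hypothetical decoded text: the prompt is echoed back, followed by the completion
decoded = "### Answer\nHere is the SQL query\n[SQL]\nSELECT COUNT(*) FROM sales;"
print(extract_sql(decoded))  # SELECT COUNT(*) FROM sales;
```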

In [9]:
question = "What was our revenue by product in the New York region last month?"
generated_sql = generate_query(question)

In [10]:
print(generated_sql)


SELECT p.product_id,
       SUM(s.quantity * p.price) AS revenue
FROM sales s
JOIN salespeople sp ON s.salesperson_id = sp.salesperson_id
JOIN products p ON s.product_id = p.product_id
WHERE sp.region = 'New York'
  AND s.sale_date >= (CURRENT_DATE - INTERVAL '1 month')
GROUP BY p.product_id
ORDER BY revenue DESC NULLS LAST;
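
The filter `s.sale_date >= (CURRENT_DATE - INTERVAL '1 month')` covers a rolling window that ends today, which is one possible reading of "last month". If the previous calendar month is what the question means, its bounds can be computed as in the sketch below. This is an illustration only; `previous_month_bounds` is a hypothetical helper and not something the model or the notebook produces.

```python
from datetime import date, timedelta
from typing import Tuple


def previous_month_bounds(today: date) -> Tuple[date, date]:
    """Return (start, end) of the previous calendar month as a half-open range.

    `start` is the first day of the previous month and `end` is the first day
    of the current month, so a filter would be start <= sale_date < end.
    """
    first_of_this_month = today.replace(day=1)
    last_of_previous = first_of_this_month - timedelta(days=1)
    return last_of_previous.replace(day=1), first_of_this_month


print(previous_month_bounds(date(2024, 3, 15)))
# (datetime.date(2024, 2, 1), datetime.date(2024, 3, 1))
```

Comparing the two interpretations is one concrete way to judge whether a generated query answers the question as intended.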


# Exercise
 - Complete the prompts similar to what we did in class.
     - Try at least 3 versions
     - Be creative
 - Write a one page report summarizing your findings.
     - Were there variations that didn't work well, i.e., where the model either hallucinated or was wrong?
 - What did you learn?

In [11]:
import sqlparse

def generate_query(prompt, question):
    updated_prompt = prompt.format(question=question)
    inputs = tokenizer(updated_prompt, return_tensors="pt").to("cuda")
    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=400,
        do_sample=False,
        num_beams=1,
    )
    outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    return sqlparse.format(outputs[0].split("[SQL]")[-1], reindent=True)

def run_all_queries(prompts, questions):
    results = {}
    for prompt_name, prompt in prompts.items():
        results[prompt_name] = {}
        for question in questions[prompt_name]:
            generated_sql = generate_query(prompt, question)
            results[prompt_name][question] = generated_sql
    return results
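
Because `run_all_queries` calls the GPU-backed `generate_query`, it can be handy to exercise the looping logic without loading a model. The sketch below is a hedged variant, not the notebook's own function: `run_all_queries_with` takes the generator as a hypothetical `generate_fn` parameter, and `fake_generate` is a stand-in that returns a placeholder string instead of real SQL.

```python
from typing import Callable, Dict, List


def run_all_queries_with(
    generate_fn: Callable[[str, str], str],
    prompts: Dict[str, str],
    questions: Dict[str, List[str]],
) -> Dict[str, Dict[str, str]]:
    """Build the nested {prompt_name: {question: sql}} mapping."""
    results = {}
    for prompt_name, prompt in prompts.items():
        # .get() skips schemas without questions; the original indexing
        # with questions[prompt_name] would raise KeyError instead
        results[prompt_name] = {
            question: generate_fn(prompt, question)
            for question in questions.get(prompt_name, [])
        }
    return results


def fake_generate(prompt: str, question: str) -> str:
    """Stand-in for the model call; returns a placeholder, not real SQL."""
    return f"-- SQL for: {question}"


demo_prompts = {"demo": "[QUESTION]{question}[/QUESTION]\n[SQL]\n"}
demo_questions = {"demo": ["How many rows?"]}
print(run_all_queries_with(fake_generate, demo_prompts, demo_questions))
# {'demo': {'How many rows?': '-- SQL for: How many rows?'}}
```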


In [12]:
prompt1 = """### Task
Generate a SQL query to answer the following question: [QUESTION]{question}[/QUESTION]

### Instructions
- If the question cannot be answered with the available database schema, return 'I do not know'.
- Revenue is calculated as price multiplied by quantity.
- Cost is calculated as supply_price multiplied by quantity.

### Database Schema
The query will run on a database with the following schema:
CREATE TABLE products (
  product_id INTEGER PRIMARY KEY,
  name VARCHAR(50),
  price DECIMAL(10,2),
  quantity INTEGER
);

CREATE TABLE customers (
   customer_id INTEGER PRIMARY KEY,
   name VARCHAR(50),
   email VARCHAR(100)
);

CREATE TABLE orders (
  order_id INTEGER PRIMARY KEY,
  customer_id INTEGER,
  order_date DATE,
  total_amount DECIMAL(10,2)
);

CREATE TABLE order_items (
  order_item_id INTEGER PRIMARY KEY,
  order_id INTEGER,
  product_id INTEGER,
  quantity INTEGER,
  price DECIMAL(10,2)
);

-- orders.customer_id can be joined with customers.customer_id
-- order_items.order_id can be joined with orders.order_id
-- order_items.product_id can be joined with products.product_id

### Answer
Based on the database schema, here is the SQL query that answers the question: [QUESTION]{question}[/QUESTION]
[SQL]
"""

In [13]:
prompt2 = """### Task
Generate a SQL query to answer the following question: [QUESTION]{question}[/QUESTION]

### Instructions
- If the question cannot be answered with the available database schema, return 'I do not know'.

### Database Schema
The query will run on a database with the following schema:
CREATE TABLE students (
  student_id INTEGER PRIMARY KEY,
  name VARCHAR(50),
  major VARCHAR(50),
  enrollment_year INTEGER
);

CREATE TABLE courses (
  course_id INTEGER PRIMARY KEY,
  name VARCHAR(50),
  department VARCHAR(50)
);

CREATE TABLE enrollments (
  enrollment_id INTEGER PRIMARY KEY,
  student_id INTEGER,
  course_id INTEGER,
  semester VARCHAR(10),
  grade CHAR(1)
);

-- enrollments.student_id can be joined with students.student_id
-- enrollments.course_id can be joined with courses.course_id

### Answer
Based on the database schema, here is the SQL query that answers the question: [QUESTION]{question}[/QUESTION]
[SQL]
"""

In [14]:
prompt3 = """### Task
Generate a SQL query to answer the following question: [QUESTION]{question}[/QUESTION]

### Instructions
- If the question cannot be answered with the available database schema, return 'I do not know'.

### Database Schema
The query will run on a database with the following schema:
CREATE TABLE patients (
  patient_id INTEGER PRIMARY KEY,
  name VARCHAR(50),
  date_of_birth DATE,
  gender CHAR(1)
);

CREATE TABLE doctors (
  doctor_id INTEGER PRIMARY KEY,
  name VARCHAR(50),
  specialty VARCHAR(50)
);

CREATE TABLE appointments (
  appointment_id INTEGER PRIMARY KEY,
  patient_id INTEGER,
  doctor_id INTEGER,
  appointment_date DATE,
  diagnosis VARCHAR(100)
);

-- appointments.patient_id can be joined with patients.patient_id
-- appointments.doctor_id can be joined with doctors.doctor_id

### Answer
Based on the database schema, here is the SQL query that answers the question: [QUESTION]{question}[/QUESTION]
[SQL]
"""

In [15]:
prompt4 = """### Task
Generate a SQL query to answer the following question: [QUESTION]{question}[/QUESTION]

### Instructions
- If the question cannot be answered with the available database schema, return 'I do not know'.

### Database Schema
The query will run on a database with the following schema:
CREATE TABLE books (
  book_id INTEGER PRIMARY KEY,
  title VARCHAR(100),
  author VARCHAR(50),
  published_year INTEGER
);

CREATE TABLE members (
  member_id INTEGER PRIMARY KEY,
  name VARCHAR(50),
  membership_date DATE
);

CREATE TABLE borrowings (
  borrowing_id INTEGER PRIMARY KEY,
  book_id INTEGER,
  member_id INTEGER,
  borrow_date DATE,
  return_date DATE
);

-- borrowings.book_id can be joined with books.book_id
-- borrowings.member_id can be joined with members.member_id

### Answer
Based on the database schema, here is the SQL query that answers the question: [QUESTION]{question}[/QUESTION]
[SQL]
"""

In [16]:
prompts = {
    "ecommerce": prompt1,
    "university": prompt2,
    "healthcare": prompt3,
    "library": prompt4
}
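
The four prompts above share the same skeleton and differ only in their instructions and schema, so a template builder can cut down on the repetition. The following is a minimal sketch under that assumption: `PROMPT_SKELETON`, `build_prompt`, and the `__INSTRUCTIONS__` / `__SCHEMA__` tokens are hypothetical names introduced here, and the one-table schema is invented for the demonstration.

```python
from typing import List

# Shared layout of the prompts; {question} is left for str.format later
PROMPT_SKELETON = (
    "### Task\n"
    "Generate a SQL query to answer the following question: "
    "[QUESTION]{question}[/QUESTION]\n\n"
    "### Instructions\n"
    "__INSTRUCTIONS__\n\n"
    "### Database Schema\n"
    "The query will run on a database with the following schema:\n"
    "__SCHEMA__\n\n"
    "### Answer\n"
    "Based on the database schema, here is the SQL query that answers "
    "the question: [QUESTION]{question}[/QUESTION]\n"
    "[SQL]\n"
)


def _escape_braces(text: str) -> str:
    """Double any braces so a later str.format call leaves them untouched."""
    return text.replace("{", "{{").replace("}", "}}")


def build_prompt(schema: str, instructions: List[str]) -> str:
    """Return a prompt template that still contains the {question} placeholder."""
    bullets = "\n".join(f"- {_escape_braces(line)}" for line in instructions)
    return (
        PROMPT_SKELETON
        .replace("__INSTRUCTIONS__", bullets)
        .replace("__SCHEMA__", _escape_braces(schema.strip()))
    )


# Invented one-table schema, used only to show the mechanics
template = build_prompt(
    "CREATE TABLE t (id INTEGER PRIMARY KEY);",
    ["If the question cannot be answered with the available database schema, "
     "return 'I do not know'."],
)
print(template.format(question="How many rows are in t?"))
```

A template built this way could be stored in the `prompts` dictionary alongside the hand-written ones, since it is formatted with `question=` in the same way.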

In [19]:
questions = {
    "ecommerce": [
        "What was the total revenue generated last month?",
        "Which product had the highest sales in the last quarter?",
        "How many customers made purchases in the last year?",
        "What is the average order value for orders placed in the last month?",
        "Which customer has made the most purchases?"
    ],
    "university": [
        "How many students enrolled in the Computer Science department in 2022?",
        "Which course had the highest number of enrollments in the Fall semester?",
        "What is the average grade for students in the Mathematics department?",
        "How many students graduated in 2021?",
        "Which student has the highest GPA?"
    ],
    "healthcare": [
        "How many appointments were made in the last month?",
        "Which doctor had the most appointments in the last year?",
        "What is the most common diagnosis for patients over 60 years old?",
        "How many patients visited the cardiology department in the last quarter?",
        "Which patient has the most appointments?"
    ],
    "library": [
        "How many books were borrowed in the last month?",
        "Which book was borrowed the most in the last year?",
        "How many members joined the library in 2022?",
        "What is the average borrowing duration for books?",
        "Which member has borrowed the most books?"
    ]
}

In [20]:
results = run_all_queries(prompts, questions)
for prompt_name, queries in results.items():
    print(f"Results for {prompt_name} schema:")
    for question, sql in queries.items():
        print(f"Question: {question}")
        print(f"Generated SQL: {sql}\n")

Results for ecommerce schema:
Question: What was the total revenue generated last month?
Generated SQL: 
SELECT SUM(o.total_amount) AS total_revenue
FROM orders o
WHERE o.order_date >= (CURRENT_DATE - INTERVAL '1 month');

Question: Which product had the highest sales in the last quarter?
Generated SQL: 
SELECT p.name,
       SUM(oi.price * oi.quantity) AS revenue,
       SUM(oi.price * oi.quantity) / NULLIF(COUNT(DISTINCT o.order_id), 0) AS avg_order_revenue
FROM products p
JOIN order_items oi ON p.product_id = oi.product_id
JOIN orders o ON oi.order_id = o.order_id
WHERE EXTRACT(QUARTER
              FROM o.order_date) = 4
  AND EXTRACT(YEAR
              FROM o.order_date) = EXTRACT(YEAR
                                           FROM CURRENT_DATE)
GROUP BY p.name
ORDER BY revenue DESC NULLS LAST
LIMIT 1;

Question: How many customers made purchases in the last year?
Generated SQL: 
SELECT COUNT(DISTINCT c.customer_id)
FROM customers c
JOIN orders o ON c.customer_id = o.customer_id


# Conclusion

In my opinion, the model performed well, generating correct and relevant SQL queries for most questions. For questions that required information not present in the schema (for example, graduation year or GPA in the university schema), it replied that it did not know, as instructed in the prompt.

The generated SQL queries were generally relevant to the questions asked, demonstrating the model's ability to understand the context and requirements of the questions.

Next, I should focus on improving the model's ability to handle incomplete schemas and on refining the prompts to better guide its query generation.
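
The refusal behaviour described above can be tallied automatically from a results dictionary shaped like the one returned by `run_all_queries`. The sketch below is a hedged illustration: `find_refusals` is a hypothetical helper, and the `sample_results` values are invented placeholders, not actual model output.

```python
from typing import Dict, List


def find_refusals(
    results: Dict[str, Dict[str, str]],
    refusal_text: str = "I do not know",
) -> Dict[str, List[str]]:
    """Map each prompt name to the questions whose output contains the refusal text."""
    refusals = {}
    for prompt_name, queries in results.items():
        refused = [
            question
            for question, sql in queries.items()
            if refusal_text.lower() in sql.lower()
        ]
        if refused:
            refusals[prompt_name] = refused
    return refusals


# Invented placeholder data with the same nesting as the real results
sample_results = {
    "university": {
        "Which student has the highest GPA?": "I do not know",
        "How many courses are there?": "SELECT COUNT(*) FROM courses;",
    },
    "library": {
        "How many books are there?": "SELECT COUNT(*) FROM books;",
    },
}
print(find_refusals(sample_results))
# {'university': ['Which student has the highest GPA?']}
```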