<a href="https://colab.research.google.com/github/pbeles/lab-sql-generation-with-transformer-api/blob/main/lab-sql-generation-with-transformer-api.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# SQL Generation with Transformer API

In [1]:
!pip install torch transformers bitsandbytes accelerate sqlparse

Collecting bitsandbytes
  Downloading bitsandbytes-0.45.1-py3-none-manylinux_2_24_x86_64.whl.metadata (5.8 kB)
Downloading bitsandbytes-0.45.1-py3-none-manylinux_2_24_x86_64.whl (69.7 MB)
Installing collected packages: bitsandbytes
Successfully installed bitsandbytes-0.45.1


In [2]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

In [3]:
torch.cuda.is_available()

True

In [4]:
available_memory = torch.cuda.get_device_properties(0).total_memory

In [5]:
print(available_memory)

42481811456


## Download the Model
Use any GPU on Colab (or any system with more than 30 GB of VRAM on your own machine) to load this model in float16. If that is unavailable, use a GPU with at least 8 GB of VRAM to load it in 8-bit, or at least 5 GB of VRAM to load it in 4-bit. The code below loads in float16 when more than 15 GB of GPU memory is available and falls back to 8-bit otherwise.

This step can take around 5 minutes the first time, so please be patient :)
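The memory thresholds above can be written as a small selection function. This is a minimal sketch, not part of the original notebook: `choose_load_mode` and the constant names are hypothetical, and the cutoffs are the 30/8/5 GB figures from the text (treating 1 GB as 1e9 bytes, like the `15e9` literal in the loading cell). Keep in mind that the loading cell itself switches to float16 above 15 GB and has no 4-bit branch.

```python
from typing import Optional

# Hypothetical thresholds taken from the prose above (decimal GB: 1 GB = 1e9 bytes).
FLOAT16_MIN_BYTES = 30e9  # "more than 30 GB" for float16
INT8_MIN_BYTES = 8e9      # "at least 8 GB" for 8-bit
INT4_MIN_BYTES = 5e9      # "at least 5 GB" for 4-bit


def choose_load_mode(total_memory_bytes: float) -> Optional[str]:
    """Return "float16", "8bit", or "4bit", or None if the GPU is too small."""
    if total_memory_bytes > FLOAT16_MIN_BYTES:
        return "float16"
    if total_memory_bytes >= INT8_MIN_BYTES:
        return "8bit"
    if total_memory_bytes >= INT4_MIN_BYTES:
        return "4bit"
    return None


# The GPU in this notebook reports 42481811456 bytes of total memory.
print(choose_load_mode(42481811456))  # float16
```

A `"4bit"` result would correspond to passing `quantization_config=BitsAndBytesConfig(load_in_4bit=True)` to `from_pretrained`.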

In [7]:
from google.colab import userdata
HF_TOKEN = userdata.get('hugging_face')

In [8]:
from transformers import BitsAndBytesConfig

model_name = "defog/sqlcoder-7b-2"
tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)
if available_memory > 15e9:
    # With more than 15 GB of GPU memory, load the model in float16
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        device_map="auto",
        use_cache=True,
        token=HF_TOKEN,
    )
else:
    # Otherwise, load in 8-bit via bitsandbytes (this is a bit slower)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        device_map="auto",
        use_cache=True,
        token=HF_TOKEN,
    )


## Set the Question & Prompt and Tokenize
Feel free to replace the schema in the prompt below with your own.

In [9]:
prompt = """### Task
Generate a SQL query to answer [QUESTION]{question}[/QUESTION]

### Instructions
- If you cannot answer the question with the available database schema, return 'I do not know'
- Remember that revenue is price multiplied by quantity
- Remember that cost is supply_price multiplied by quantity

### Database Schema
This query will run on a database whose schema is represented in this string:
CREATE TABLE products (
  product_id INTEGER PRIMARY KEY, -- Unique ID for each product
  name VARCHAR(50), -- Name of the product
  price DECIMAL(10,2), -- Price of each unit of the product
  quantity INTEGER  -- Current quantity in stock
);

CREATE TABLE customers (
   customer_id INTEGER PRIMARY KEY, -- Unique ID for each customer
   name VARCHAR(50), -- Name of the customer
   address VARCHAR(100) -- Mailing address of the customer
);

CREATE TABLE salespeople (
  salesperson_id INTEGER PRIMARY KEY, -- Unique ID for each salesperson
  name VARCHAR(50), -- Name of the salesperson
  region VARCHAR(50) -- Geographic sales region
);

CREATE TABLE sales (
  sale_id INTEGER PRIMARY KEY, -- Unique ID for each sale
  product_id INTEGER, -- ID of product sold
  customer_id INTEGER,  -- ID of customer who made purchase
  salesperson_id INTEGER, -- ID of salesperson who made the sale
  sale_date DATE, -- Date the sale occurred
  quantity INTEGER -- Quantity of product sold
);

CREATE TABLE product_suppliers (
  supplier_id INTEGER PRIMARY KEY, -- Unique ID for each supplier
  product_id INTEGER, -- Product ID supplied
  supply_price DECIMAL(10,2) -- Unit price charged by supplier
);

-- sales.product_id can be joined with products.product_id
-- sales.customer_id can be joined with customers.customer_id
-- sales.salesperson_id can be joined with salespeople.salesperson_id
-- product_suppliers.product_id can be joined with products.product_id

### Answer
Given the database schema, here is the SQL query that answers [QUESTION]{question}[/QUESTION]
[SQL]
"""
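Because the schema is pasted directly into `prompt`, swapping in your own means editing the template itself. One alternative, sketched here under the assumption that you are free to restructure the template, is to keep the schema in its own variable and fill it through a second placeholder. `SCHEMA`, `TEMPLATE`, and `build_prompt` are hypothetical names, and the template is abbreviated (the Instructions section and most tables are dropped) to keep the example short.

```python
# Hypothetical layout: the schema lives in its own variable so it can be
# swapped without touching the rest of the template.
SCHEMA = """CREATE TABLE products (
  product_id INTEGER PRIMARY KEY,
  name VARCHAR(50),
  price DECIMAL(10,2),
  quantity INTEGER
);"""

TEMPLATE = """### Task
Generate a SQL query to answer [QUESTION]{question}[/QUESTION]

### Database Schema
This query will run on a database whose schema is represented in this string:
{schema}

### Answer
Given the database schema, here is the SQL query that answers [QUESTION]{question}[/QUESTION]
[SQL]
"""


def build_prompt(question: str, schema: str = SCHEMA) -> str:
    # str.format fills every {question} occurrence (there are two) and the
    # {schema} slot; braces inside the substituted values are not re-parsed.
    return TEMPLATE.format(question=question, schema=schema)


p = build_prompt("How many products are in stock?")
print(p.count("How many products are in stock?"))  # 2
```

Values passed to `str.format` are inserted verbatim, so a schema that itself contains braces is safe; only literal braces written in the template would need doubling as `{{` and `}}`.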

## Generate the SQL
Generation can be excruciatingly slow on a T4 in Colab, taking 10-20 seconds per query. On faster GPUs, it takes ~1-2 seconds.

Ideally, you should use `num_beams=4` for best results. This GPU has enough memory, so the code below uses 4; if you run into memory constraints, reduce it to 1.
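To see why `num_beams` can matter, here is a toy beam search over a hand-written next-token table. It is not the real model: `TOY_MODEL` and `beam_search` are hypothetical, and the search is simplified to a fixed number of steps. Greedy decoding (one beam) commits to the locally most likely first token, while two beams can keep a lower-scoring prefix that leads to a more probable full sequence.

```python
import math

# Hypothetical toy "model": next-token probabilities for each prefix.
TOY_MODEL = {
    (): {"A": 0.6, "B": 0.4},
    ("A",): {"x": 0.55, "y": 0.45},
    ("B",): {"x": 0.9, "y": 0.1},
}


def beam_search(model, num_beams, steps):
    """Keep the num_beams highest log-probability sequences at each step."""
    beams = [((), 0.0)]  # (token sequence, cumulative log-probability)
    for _ in range(steps):
        # Extend every surviving beam by every possible next token.
        candidates = [
            (seq + (tok,), score + math.log(p))
            for seq, score in beams
            for tok, p in model[seq].items()
        ]
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:num_beams]
    best_seq, best_score = beams[0]
    return best_seq, math.exp(best_score)


print(beam_search(TOY_MODEL, num_beams=1, steps=2))  # greedy: ('A', 'x'), p ≈ 0.33
print(beam_search(TOY_MODEL, num_beams=2, steps=2))  # ('B', 'x'), p ≈ 0.36
```

The real `generate` call also handles end-of-sequence tokens and length normalization, so this only illustrates the core idea. Each extra beam means more sequences held in memory, which is the trade-off behind reducing `num_beams` on smaller GPUs.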

In [11]:
import sqlparse

def generate_query(question):
    updated_prompt = prompt.format(question=question)
    inputs = tokenizer(updated_prompt, return_tensors="pt").to("cuda")
    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=400,
        do_sample=False,
        num_beams=4,
    )
    outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    # Free cached GPU memory so that you can generate more results without
    # running out of memory. This is particularly important on Colab; memory
    # management is much more straightforward when running on an inference service.
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    return sqlparse.format(outputs[0].split("[SQL]")[-1], reindent=True)
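`generate_query` keeps the text after the last `[SQL]` marker because a decoder-only model's `generate` output begins with the prompt tokens, so the decoded string echoes the prompt before the answer. The prompt also allows the answer 'I do not know', which `sqlparse.format` would pass through unchanged. A hypothetical post-processing helper, `extract_sql`, written with the standard library only, could return `None` for anything that does not start like a query. Treating only `SELECT` and `WITH` as valid starts is an assumption that fits the read-only questions in this notebook.

```python
from typing import Optional


def extract_sql(decoded: str) -> Optional[str]:
    """Return the text after the last [SQL] marker if it looks like a query, else None."""
    answer = decoded.split("[SQL]")[-1].strip()
    words = answer.split(None, 1)
    # Assumption: a usable answer is a SELECT or a CTE (WITH ...).
    if words and words[0].upper() in {"SELECT", "WITH"}:
        return answer
    return None  # e.g. the "I do not know" answer the prompt allows


print(extract_sql("...prompt...\n[SQL]\nSELECT 1;"))      # SELECT 1;
print(extract_sql("...prompt...\n[SQL]\nI do not know"))  # None
```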

In [12]:
question = "What was our revenue by product in the New York region last month?"
generated_sql = generate_query(question)

In [13]:
print(generated_sql)


SELECT p.product_id,
       SUM(s.quantity * p.price) AS revenue
FROM sales s
JOIN salespeople sp ON s.salesperson_id = sp.salesperson_id
JOIN products p ON s.product_id = p.product_id
WHERE sp.region = 'New York'
  AND s.sale_date >= (CURRENT_DATE - INTERVAL '1 month')
GROUP BY p.product_id
ORDER BY revenue DESC NULLS LAST;


# Exercise
 - Complete the prompts, similar to what we did in class.
     - Try at least 3 versions.
     - Be creative.
 - Write a one-page report summarizing your findings.
     - Were there variations that didn't work well, i.e., where the model hallucinated or produced wrong SQL?
 - What did you learn?

In [17]:
def test_prompts():
    # Base schema from the original code
    base_schema = """
    CREATE TABLE products (
        product_id INTEGER PRIMARY KEY,
        name VARCHAR(50),
        price DECIMAL(10,2),
        quantity INTEGER
    );
    -- [rest of schema omitted; the full schema is in `prompt` above]
    """

    # Test cases - different prompt variations
    test_prompts = [
        # Variation 1: Direct business question (real-world)
        """### Task
        Generate a SQL query to answer [QUESTION]{question}[/QUESTION]

        ### Instructions
        - Return clear, optimized SQL queries
        - Join tables only when necessary
        - Use appropriate aggregations for analysis

        ### Database Schema
        {schema}
        """,

        # Variation 2: Step-by-step reasoning
        """### Task
        Follow these steps to generate SQL:
        1. Identify required tables from: {schema}
        2. Determine necessary joins
        3. Select appropriate columns
        4. Apply filters and aggregations
        5. Generate SQL for [QUESTION]{question}[/QUESTION]
        """,

        # Variation 3: Context-rich
        """### Task
        You are a database analyst helping business users.
        Question: [QUESTION]{question}[/QUESTION]

        Requirements:
        - Consider performance implications
        - Ensure data accuracy
        - Use clear column aliases

        Schema: {schema}
        """,
        # Variation 4: Technical specification
        """### Technical Requirements
        - Query Type: Business Intelligence
        - Performance Target: Sub-second response
        - Data Consistency: ACID compliant

        ### Question Context
        [QUESTION]{question}[/QUESTION]

        ### Available Tables
        {schema}

        ### Output Requirements
        - Must include error handling
        - Results should be ordered meaningfully
        - Include appropriate indexing hints
        """
    ]

    # Test questions
    test_questions = [
        "What is the total revenue by region in the last quarter?",
        "Which products have the highest profit margin based on supply price?",
        "Who are the top 5 salespeople by number of units sold?",
        "What is the average order value by customer segment?",
        "Which regions have declining sales compared to last year?"
    ]

    results = []
    for prompt_template in test_prompts:
        prompt_results = []
        for question in test_questions:
            # Format the prompt with the current question and schema
            formatted_prompt = prompt_template.format(
                question=question,
                schema=base_schema
            )

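            # Generate SQL using generate_query() from earlier in the notebook.
            # Note that generate_query() inserts this whole text as {question}
            # inside the original `prompt`, so each variation is nested in that
            # template (which also supplies the full schema) rather than sent alone.
            sql_query = generate_query(formatted_prompt)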
            prompt_results.append({
                'question': question,
                'generated_sql': sql_query
            })

        results.append(prompt_results)

    return results

# Execute the tests
results = test_prompts()

# Print results
for i, prompt_results in enumerate(results, 1):
    print(f"\nPrompt Variation {i} Results:")
    print("-" * 50)
    for result in prompt_results:
        print(f"\nQuestion: {result['question']}")
        print(f"Generated SQL:\n{result['generated_sql']}\n")



Prompt Variation 1 Results:
--------------------------------------------------

Question: What is the total revenue by region in the last quarter?
Generated SQL:

SELECT s.region,
       SUM(p.price * p.quantity) AS total_revenue
FROM sales s
JOIN products p ON s.product_id = p.product_id
WHERE EXTRACT(QUARTER
              FROM s.sale_date) = 4
  AND EXTRACT(YEAR
              FROM s.sale_date) = EXTRACT(YEAR
                                          FROM CURRENT_DATE)
GROUP BY s.region
ORDER BY total_revenue DESC NULLS LAST;


Question: Which products have the highest profit margin based on supply price?
Generated SQL:

SELECT p.name,
       (p.price - ps.supply_price) AS profit_margin
FROM products p
JOIN product_suppliers ps ON p.product_id = ps.product_id
ORDER BY profit_margin DESC NULLS LAST;


Question: Who are the top 5 salespeople by number of units sold?
Generated SQL:

SELECT s.salesperson_id,
       s.name,
       SUM(s.quantity) AS total_quantity_sold
FROM sales s
GROUP 

## Report

After spending some time experimenting with different prompt structures, I noticed some interesting patterns that might help improve SQL query generation. Here's what I found during my testing.

### The Good Stuff

Breaking complex queries down into step-by-step instructions worked surprisingly well. It's like having a GPS for database queries: each turn is clearly marked, making it harder to get lost. The model seemed most confident when we laid out a clear path: "First, grab these tables, then join here, then filter by that."
### The Not-So-Great

Sometimes the model got a bit too creative, especially with simpler queries. It's like using a sledgehammer to hang a picture: technically it works, but there are better tools for the job. It struggled in particular with time-based comparisons, occasionally reinventing the wheel instead of using standard date functions.
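The time-based issue is visible in the outputs above: the "last quarter" query from Variation 1 filters on quarter 4 of the current year, which is never the quarter before today's date, and the earlier "last month" query uses a rolling `CURRENT_DATE - INTERVAL '1 month'` window, which is one reasonable reading but not the previous calendar month. One way to check such filters is to compute the intended bounds outside the model. The helper names below are hypothetical, and they assume "last month/quarter" means the previous calendar period, returned as half-open `[start, end)` date ranges.

```python
from datetime import date, timedelta
from typing import Tuple


def previous_month_range(today: date) -> Tuple[date, date]:
    """Half-open [start, end) range covering the calendar month before `today`."""
    first_of_this_month = today.replace(day=1)
    last_of_prev_month = first_of_this_month - timedelta(days=1)
    return last_of_prev_month.replace(day=1), first_of_this_month


def previous_quarter_range(today: date) -> Tuple[date, date]:
    """Half-open [start, end) range covering the calendar quarter before `today`."""
    # Quarters start in months 1, 4, 7, and 10.
    this_q_start = date(today.year, 3 * ((today.month - 1) // 3) + 1, 1)
    if this_q_start.month == 1:
        prev_q_start = date(today.year - 1, 10, 1)  # Q1 -> Q4 of the prior year
    else:
        prev_q_start = date(today.year, this_q_start.month - 3, 1)
    return prev_q_start, this_q_start


print(previous_quarter_range(date(2025, 2, 10)))
# (datetime.date(2024, 10, 1), datetime.date(2025, 1, 1))
```

These bounds could then be supplied as query parameters, e.g. `WHERE s.sale_date >= ? AND s.sale_date < ?`, or used to check the date filter a generated query applies.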
### Key Takeaways

- Being explicit about which tables to join is like giving good directions: it just works better.
- When we break complex questions into smaller steps, we get fewer "creative interpretations".
- Including performance hints in the prompt helped produce more efficient queries.
- The model needs clear guardrails around the schema context to avoid wandering off into non-existent tables or columns. For example, the Variation 1 revenue query selects `s.region`, but `region` is a column of `salespeople`, not `sales`.
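One cheap guardrail is to compile each generated query against the schema before trusting it. The sketch below (the name `check_against_schema` is hypothetical) creates the prompt's tables in an in-memory SQLite database and runs `EXPLAIN`, which compiles the statement and returns its bytecode listing without executing the query. Using SQLite is an assumption: Postgres-style syntax such as `INTERVAL '1 month'` or `EXTRACT(QUARTER ...)` is also rejected by SQLite, so for those outputs you would validate against the target database instead.

```python
import sqlite3
from typing import Optional

# Tables copied from the prompt's schema (column comments dropped).
SCHEMA_DDL = """
CREATE TABLE products (product_id INTEGER PRIMARY KEY, name VARCHAR(50),
                       price DECIMAL(10,2), quantity INTEGER);
CREATE TABLE customers (customer_id INTEGER PRIMARY KEY, name VARCHAR(50),
                        address VARCHAR(100));
CREATE TABLE salespeople (salesperson_id INTEGER PRIMARY KEY, name VARCHAR(50),
                          region VARCHAR(50));
CREATE TABLE sales (sale_id INTEGER PRIMARY KEY, product_id INTEGER,
                    customer_id INTEGER, salesperson_id INTEGER,
                    sale_date DATE, quantity INTEGER);
CREATE TABLE product_suppliers (supplier_id INTEGER PRIMARY KEY, product_id INTEGER,
                                supply_price DECIMAL(10,2));
"""


def check_against_schema(sql: str) -> Optional[str]:
    """Return None if SQLite can compile `sql` against the schema, else the error text."""
    conn = sqlite3.connect(":memory:")
    try:
        conn.executescript(SCHEMA_DDL)
        # EXPLAIN compiles the statement (resolving tables and columns)
        # and lists its bytecode instead of running the query.
        conn.execute("EXPLAIN " + sql)
    except sqlite3.Error as exc:
        return str(exc)
    finally:
        conn.close()
    return None


print(check_against_schema(
    "SELECT s.region, SUM(p.price * p.quantity) AS total_revenue "
    "FROM sales s JOIN products p ON s.product_id = p.product_id GROUP BY s.region"
))  # expected: a "no such column" error for s.region
```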

Interestingly, the more structured prompts consistently outperformed the free-form ones. It's like the difference between following a recipe and winging it in the kitchen: both might work, but one has a higher success rate.

This isn't the final word, and there's still room for experimentation and improvement. But it's a solid starting point for anyone looking to generate reliable SQL queries.