# SQL Generative QA

This workflow fine-tunes a T5 sequence-to-sequence model on a SQL dataset of natural language questions, SQL context, and the SQL queries that correctly answer each question. It gives a high-level overview of how the system works. An accompanying application abstracts these steps away and lets a user interact with a pretrained model.

Main components:

Question: Natural language question\
Context: Table definitions with their relevant columns, plus synthetic data to populate them\
Answer: Correct SQL query that answers the question

## Setup

In [1]:
import torch

device = 'gpu' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')

if device == 'gpu':
    print(f'GPU: {torch.cuda.get_device_name(0)}')

Using gpu device
GPU: NVIDIA GeForce RTX 4070 SUPER


In [2]:
from datasets import load_dataset

sql_data = load_dataset('gretelai/synthetic_text_to_sql')

In [3]:
from transformers import T5Tokenizer, T5ForConditionalGeneration

model_name = "t5-small"
model = T5ForConditionalGeneration.from_pretrained(model_name)
tokenizer = T5Tokenizer.from_pretrained(model_name)

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


## Preprocessing

We only need the `sql_prompt`, `sql_context`, and `sql` columns from this dataset. The question and context are concatenated into the model input, the SQL query becomes the label, and both are tokenized with T5's tokenizer.
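A minimal sketch of how a single record becomes an input/target pair, mirroring the f-string used in `tokenize` below. The record is made up for illustration; only the three column names come from the dataset, and `build_example` is a hypothetical helper, not part of the notebook.

```python
def build_example(record: dict) -> tuple[str, str]:
    """Return (input_text, target_text) for one dataset row."""
    # Same prompt template the notebook uses for training and inference
    input_text = f"Translate to SQL: {record['sql_prompt']} Context: {record['sql_context']}"
    # The reference SQL query is the target sequence
    target_text = record["sql"]
    return input_text, target_text

# Hypothetical record with the three columns the notebook keeps
record = {
    "sql_prompt": "What is the name of the employee with id 1?",
    "sql_context": "CREATE TABLE employees (id INT, name TEXT);",
    "sql": "SELECT name FROM employees WHERE id = 1;",
}
source, target = build_example(record)
print(source)
print(target)
```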

In [4]:
def tokenize(batch: dict) -> dict:
    '''
    Tokenizes a batch of examples with the T5 tokenizer.
    The question and context are concatenated into the input; the SQL query is the label.

    Args:
        batch (dict): A batch from `Dataset.map(batched=True)`, mapping column names to lists of values

    Returns:
        dict: `input_ids`, `attention_mask`, and `labels` for the batch
    '''
    # Prompt template; format_tests uses the same template at inference time
    concat = [f"Translate to SQL: {q} Context: {c}" for q, c in zip(batch["sql_prompt"], batch["sql_context"])]
    # Pad/truncate every example to a fixed length so DefaultDataCollator can stack them
    tokenized_inputs = tokenizer(concat, max_length=500, padding="max_length", truncation=True)
    tokenized_labels = tokenizer(batch["sql"], max_length=500, padding="max_length", truncation=True)

    return {
        'input_ids': tokenized_inputs['input_ids'],
        'attention_mask': tokenized_inputs['attention_mask'],
        'labels': tokenized_labels['input_ids']
    }
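`tokenize` pads the labels to 500 tokens with the pad token, and the notebook does not mask them, so those padding positions also count toward the training loss. A common fix is to replace padded label ids with -100, the default `ignore_index` of PyTorch's cross-entropy loss, which Hugging Face seq2seq models use to skip label positions. A minimal sketch of that masking step, assuming T5's pad token id of 0 (in real code read it from `tokenizer.pad_token_id`); `mask_label_padding` is a hypothetical helper and the token ids are made up.

```python
def mask_label_padding(label_batch: list[list[int]], pad_token_id: int = 0) -> list[list[int]]:
    """Replace padding ids in each label sequence with -100 so the loss skips them."""
    # -100 is the default ignore_index of PyTorch's CrossEntropyLoss
    return [[tok if tok != pad_token_id else -100 for tok in labels] for labels in label_batch]

# Made-up label ids: 1 is assumed to be T5's </s> token, 0 its padding token
labels = [[71, 18, 1, 0, 0], [5, 1, 0, 0, 0]]
print(mask_label_padding(labels))
```

Applied inside `tokenize`, this would run on the label ids before they are returned; with PyTorch tensors, `labels[labels == tokenizer.pad_token_id] = -100` does the same thing.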

In [5]:
train_tokenized = sql_data['train'].map(tokenize, batched=True, remove_columns=sql_data['train'].column_names)

In [6]:
test_tokenized = sql_data['test'].map(tokenize, batched=True, remove_columns=sql_data['test'].column_names)

## Fine-tune T5

In [None]:
from transformers import TrainingArguments, Trainer, DefaultDataCollator

data_collator = DefaultDataCollator()

training_args = TrainingArguments(
    output_dir="sql_generator_f5",
    eval_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=3,
    weight_decay=0.01,
    push_to_hub=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_tokenized,
    eval_dataset=test_tokenized,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

trainer.train()
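The checkpoint loaded later, `checkpoint-75000`, is named after the optimizer step at which the Trainer saved it. A quick arithmetic check, assuming the `train` split has 100,000 rows (the size stated on the dataset card), a single GPU, and no gradient accumulation, suggests that 75,000 is the final step of the 3-epoch run. `total_training_steps` is a hypothetical helper.

```python
import math

def total_training_steps(num_examples: int, batch_size: int, num_epochs: int) -> int:
    """Optimizer steps on one device with no gradient accumulation."""
    # The last, partial batch of each epoch still counts as a step
    steps_per_epoch = math.ceil(num_examples / batch_size)
    return steps_per_epoch * num_epochs

# Assumed train-split size; batch size and epoch count match training_args
num_train_rows = 100_000
print(total_training_steps(num_train_rows, batch_size=4, num_epochs=3))  # 75000
```

In the notebook itself, `len(train_tokenized)` gives the actual row count to plug in.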

## Assess results

In [13]:
def format_tests(question, context, model):
    '''
    Generates a SQL query for a question and its table context

    Args:
        question (str): The natural language question to ask the model
        context (str): The CREATE TABLE / INSERT statements describing the tables
        model: The fine-tuned seq2seq model used for generation

    Returns:
        str: The generated SQL query
    '''
    # Same prompt template as in training
    input_text = f"Translate to SQL: {question} Context: {context}"
    # Move the inputs to whichever device the model is on
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512).to(model.device)

    # do_sample=True means repeated calls may return different queries
    output_ids = model.generate(**inputs, max_length=512, do_sample=True, temperature=0.6, top_k=50, top_p=0.95)
    generated_sql = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    return generated_sql
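`format_tests` samples with `temperature=0.6`, `top_k=50`, and `top_p=0.95`. The toy function below illustrates, in plain Python over made-up logits, how those three settings reshape one next-token distribution: temperature rescales the logits, top-k keeps the k most likely tokens, and top-p keeps the smallest set whose cumulative probability reaches p. It is a simplified sketch of the idea, not the `transformers` implementation, whose edge cases may differ. For reproducible evaluation, greedy decoding (`do_sample=False`) is the usual alternative.

```python
import math

def filter_next_token(logits, temperature=0.6, top_k=50, top_p=0.95):
    """Toy sketch: return {token_index: probability} left after temperature, top-k and top-p."""
    # Temperature: divide the logits, then softmax (values below 1 sharpen the distribution)
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]  # subtract the max for numerical stability
    total = sum(exps)
    probs = [e / total for e in exps]

    # Top-k: keep the k most probable tokens and renormalise over them
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]
    k_mass = sum(probs[i] for i in ranked)

    # Top-p (nucleus): keep the smallest prefix whose cumulative probability reaches top_p
    kept, cumulative = [], 0.0
    for i in ranked:
        kept.append(i)
        cumulative += probs[i] / k_mass
        if cumulative >= top_p:
            break

    # Renormalise over the survivors; sampling would draw the next token from this
    mass = sum(probs[i] for i in kept)
    return {i: probs[i] / mass for i in kept}

# Made-up logits for a 4-token vocabulary
print(filter_next_token([2.0, 1.0, 0.5, -1.0]))
```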

In [None]:
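from transformers import AutoModelForSeq2SeqLM, T5Tokenizer
import torch
import sqlparse

# Checkpoint saved by the Trainer under output_dir
model_path = "sql_generator_f5/checkpoint-75000"
model = AutoModelForSeq2SeqLM.from_pretrained(model_path).to("cuda" if torch.cuda.is_available() else "cpu")
# Base t5-small tokenizer, the same one used for preprocessing
tokenizer = T5Tokenizer.from_pretrained('t5-small')
# Inference mode (disables dropout)
model.eval()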

In [14]:
question = "What is the name of the employee with id 1?"
context = """
CREATE TABLE employees (id INT, name TEXT);
INSERT INTO employees (id, name) VALUES 
(1, 'Alice'), 
(2, 'Bob');
"""

generated_sql = format_tests(question, context, model)
sql_formatted = sqlparse.format(generated_sql, reindent=True, keyword_case='upper')
print(sql_formatted)

SELECT name
FROM employees
WHERE id = 1;


In [15]:
question = "How many employees work in each department?"
context = """
CREATE TABLE employees (id INT, name TEXT, department_id INT);
INSERT INTO employees (id, name, department_id) VALUES 
(1, 'Alice', 1), 
(2, 'Bob', 1), 
(3, 'Charlie', 2), 
(4, 'David', 2), 
(5, 'Eve', 3);

CREATE TABLE departments (department_id INT, department_name TEXT);
INSERT INTO departments (department_id, department_name) VALUES 
(1, 'HR'), 
(2, 'Engineering'), 
(3, 'Marketing');
"""

generated_sql = format_tests(question, context, model)
sql_formatted = sqlparse.format(generated_sql, reindent=True, keyword_case='upper')
print(sql_formatted)

SELECT department_name,
       COUNT(*) AS total_employees
FROM employees
GROUP BY department_id;


In [22]:
question = "What is the total quantity of each product sold in 2023?"
context = """
CREATE TABLE products (product_id INT, product_name TEXT);
INSERT INTO products (product_id, product_name) VALUES 
(1, 'Laptop'), 
(2, 'Phone'), 
(3, 'Tablet');

CREATE TABLE sales (sale_id INT, product_id INT, quantity INT, sale_date DATE);
INSERT INTO sales (sale_id, product_id, quantity, sale_date) VALUES 
(1, 1, 5, '2023-01-10'), 
(2, 2, 10, '2023-02-15'), 
(3, 2, 7, '2023-03-20'), 
(4, 3, 3, '2023-04-25'), 
(5, 1, 2, '2023-08-30'), 
(6, 3, 8, '2023-12-10');
"""

generated_sql = format_tests(question, context, model)
sql_formatted = sqlparse.format(generated_sql, reindent=True, keyword_case='upper')
print(sql_formatted)

SELECT p.product_name,
       SUM(s.quantity) AS total_quantity
FROM sales s
JOIN products p ON s.product_id = p.product_id
WHERE s.sale_date BETWEEN '2023-01-01' AND '2023-12-31'
GROUP BY p.product_name;
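Formatted queries are easier to judge by running them. A minimal sketch, using Python's built-in `sqlite3` as a stand-in engine (an assumption: dataset contexts may use syntax SQLite does not support), loads a `context` into an in-memory database and executes a query against it. Applied to the second example, it shows that the generated query fails: it selects `department_name`, which exists only in `departments`, without joining that table. `run_generated_sql` is a hypothetical helper, and the corrected query is hand-written for comparison, not model output.

```python
import sqlite3

def run_generated_sql(context: str, query: str):
    """Execute `query` against the tables built by `context`; return rows or the error text."""
    conn = sqlite3.connect(":memory:")  # throwaway in-memory database per check
    try:
        conn.executescript(context)  # CREATE TABLE / INSERT statements from the prompt
        return conn.execute(query).fetchall()
    except sqlite3.Error as exc:
        return f"error: {exc}"
    finally:
        conn.close()

employees_ctx = """
CREATE TABLE employees (id INT, name TEXT, department_id INT);
INSERT INTO employees (id, name, department_id) VALUES
(1, 'Alice', 1), (2, 'Bob', 1), (3, 'Charlie', 2), (4, 'David', 2), (5, 'Eve', 3);
CREATE TABLE departments (department_id INT, department_name TEXT);
INSERT INTO departments (department_id, department_name) VALUES
(1, 'HR'), (2, 'Engineering'), (3, 'Marketing');
"""

# Second generated query from above: department_name is not a column of employees
generated = "SELECT department_name, COUNT(*) AS total_employees FROM employees GROUP BY department_id;"
print(run_generated_sql(employees_ctx, generated))

# Hand-written correction that joins departments first
fixed = """SELECT d.department_name, COUNT(*) AS total_employees
FROM employees e JOIN departments d ON e.department_id = d.department_id
GROUP BY d.department_name;"""
print(run_generated_sql(employees_ctx, fixed))
```

Executing without error is a weaker check than comparing the returned rows with those of the dataset's reference `sql` query, which is a common basis for execution accuracy.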


## Model Files

The fine-tuned model is available on Hugging Face at https://huggingface.co/DevD60/sql_generator_f5