# Instruct Fine-tune Tiny Llama for Text2SQL using Supervised Fine-tuning

The goal is to take a labeled dataset in which each example pairs a SQL database schema and a natural-language question with the corresponding SQL query.

We will then use Supervised Fine-tuning (SFT) to train an LLM to generate the SQL query for a given user question and database schema, as depicted in the following workflow.

![](https://i.imgur.com/h8xFXON.png)

## Load up the Tiny Llama LLM

In [None]:
from transformers import AutoTokenizer

# Define the model to fine-tune
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# Load the tokenizer for the specified model.
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Set the padding token to be the same as the end of sentence token.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
from transformers import BitsAndBytesConfig, AutoModelForCausalLM

# Define the quantization configuration for memory-efficient training.
bnb_config = BitsAndBytesConfig(
    # Load the model weights in 4-bit quantized format.
    load_in_4bit=True,
    # Specify the quantization type to use for 4-bit quantization.
    bnb_4bit_quant_type="nf4",
    # Specify the data type to use for computations during training.
    bnb_4bit_compute_dtype="float16",
    # Specify whether to use double quantization for 4-bit quantization.
    bnb_4bit_use_double_quant=True
)

# Load the model from the specified model ID and apply the quantization configuration.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto"
)
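To see why 4-bit loading matters, here is the arithmetic for the weights alone at different precisions. This is a minimal sketch that assumes a nominal 1.1B parameters and ignores quantization constants, layers kept in higher precision, activations and optimizer state, so treat the numbers as approximate (the real 4-bit footprint is somewhat higher). `model.get_memory_footprint()` reports the actual figure for the loaded model.

```python
def weight_memory_gb(num_params, bits_per_param):
    """Approximate memory for the model weights alone, in GB (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9

NUM_PARAMS = 1.1e9  # nominal TinyLlama parameter count (assumption)

for label, bits in [("float32", 32), ("float16", 16), ("4-bit NF4", 4)]:
    print(f"{label:>9}: ~{weight_memory_gb(NUM_PARAMS, bits):.2f} GB")
```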

In [None]:
model.device

## Test the LLM with a simple prompt

This LLM has already been fine-tuned for chat on public data, so let's first try a simple prompt.

In [None]:
prompt_txt = "Explain Generative AI in 1 line"

messages = [
    {
        "role": "system",
        "content": "Act as a helpful assistant",
    },
    {"role": "user",
     "content": prompt_txt},
]

prompt = tokenizer.apply_chat_template(messages,
                                       tokenize=False,
                                       add_generation_prompt=True)
print(prompt)

In [None]:
# Encode the prompt.
inputs = tokenizer(prompt, return_tensors="pt").to('cuda')

# Generate the output.
output = model.generate(**inputs, max_new_tokens=200,
                        eos_token_id=tokenizer.eos_token_id,
                        tokenizer=tokenizer, stop_strings=["</s>"])

# Decode the output.
text = tokenizer.decode(output[0], skip_special_tokens=True)

In [None]:
print(text)

## Create Instruction Template for Instruction Tuning LLM for Text2SQL

Here we create a prompt template for the Text2SQL task: given a user question and the database schema (the context), it instructs the LLM to generate only the SQL query. The same format will be used to build the instruction-tuning dataset.

In [None]:
def sql_chat_template(question, context):
    """
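    Creates the Text2SQL inference prompt for the TinyLlama model.

    The prompt uses ChatML-style <|im_start|>/<|im_end|> tags rather than
    TinyLlama's built-in chat template; fine-tuning teaches the model this format.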

    Args:
        question: The question to be answered.
        context: The context information to be used for generating the answer.

    Returns:
        A string containing the chat template.
    """

    template = f"""\
    <|im_start|>user
    Given the following context, generate an SQL query for the following question.
    Just generate the query only and nothing else.
    Remember to only use the table columns in the context.
    context:{context}
    question:{question}
    <|im_end|>
    <|im_start|>assistant
    """
    # Remove any leading whitespace characters from each line in the template.
    template = "\n".join([line.lstrip() for line in template.splitlines()])
    return template
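One design note on this template: calling `lstrip()` on every line also strips any indentation inside the interpolated context, for example a schema written over several indented lines. A possible alternative (not the approach used in this notebook) is to dedent the static template once with `textwrap.dedent` and only then fill in the values. `sql_chat_template_dedent` is a hypothetical name for this sketch.

```python
import textwrap

# Dedent the static template once, *before* interpolation, so that any
# indentation inside the schema text is left untouched.
SQL_PROMPT_TEMPLATE = textwrap.dedent("""\
    <|im_start|>user
    Given the following context, generate an SQL query for the following question.
    Just generate the query only and nothing else.
    Remember to only use the table columns in the context.
    context:{context}
    question:{question}
    <|im_end|>
    <|im_start|>assistant
    """)

def sql_chat_template_dedent(question, context):
    """Hypothetical variant of sql_chat_template that preserves indentation in the context."""
    return SQL_PROMPT_TEMPLATE.format(question=question, context=context)

print(sql_chat_template_dedent("How many rows?", "CREATE TABLE t (\n  a INT\n)"))
```

For single-line schemas like the ones in this dataset, both versions should produce the same prompt.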

## Test Prompt with a Sample Data point on the LLM

In [None]:
question = "How many tackles for the player with over 0 fumble recovries and 0 forced fumbles?"
context = "CREATE TABLE table_name_96 (total VARCHAR, fumble_rec VARCHAR, fumble_force VARCHAR)"
sql_prompt = sql_chat_template(question,context)
print(sql_prompt)

In [None]:
# Encode the prompt.
inputs = tokenizer(sql_prompt, return_tensors="pt").to('cuda')
# Generate the output.
output = model.generate(**inputs, max_new_tokens=200,
                        eos_token_id=tokenizer.eos_token_id,
                        tokenizer=tokenizer, stop_strings=["</s>"])
# Decode the output.
text = tokenizer.decode(output[0], skip_special_tokens=True)
print(text)
# Human Answer:
# SELECT COUNT(total) FROM table_name_96 WHERE fumble_rec > 0 AND fumble_force = 0

We see that the LLM does produce a SQL query, but it is completely wrong compared with the human answer.

## Load Text2SQL Dataset

In [None]:
from datasets import load_dataset, Dataset
# Download the dataset for fine-tuning
dataset_id = "b-mc2/sql-create-context"
data = load_dataset(dataset_id, split="train")

# convert dataset to dataframe for simplicity
df = data.to_pandas()

In [None]:
df.shape

In [None]:
df.head()

For each data point (row), we have a user question, the database schema context, and the ground-truth SQL query (the `answer` column) that the LLM must learn to generate from the question and context.
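Since the prompt tells the model to use only the table columns in the context, it can help to read those columns programmatically, for instance to check which schema columns a generated query mentions. The helper below is a hypothetical sketch that only handles the simple `CREATE TABLE name (col TYPE, ...)` form seen in the sample above; column types containing parentheses (such as `VARCHAR(255)`) would break its regex.

```python
import re

# Matches simple statements of the form: CREATE TABLE name (col TYPE, col TYPE, ...)
CREATE_TABLE_RE = re.compile(r"CREATE TABLE\s+(\w+)\s*\(([^)]*)\)", re.IGNORECASE)

def schema_columns(context):
    """Map each table name in a schema string to its list of column names."""
    tables = {}
    for match in CREATE_TABLE_RE.finditer(context):
        table_name, column_defs = match.group(1), match.group(2)
        # Each definition looks like "column_name TYPE"; keep the first word.
        tables[table_name] = [d.split()[0] for d in column_defs.split(",") if d.strip()]
    return tables

sample_context = "CREATE TABLE table_name_96 (total VARCHAR, fumble_rec VARCHAR, fumble_force VARCHAR)"
print(schema_columns(sample_context))
# {'table_name_96': ['total', 'fumble_rec', 'fumble_force']}
```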

## Instruction Tuning Dataset Preparation

Here we use a variant of the previous prompt that also appends the answer, so the model learns to generate the answer from the other fields and the prompt instructions.

In [None]:
def sql_chat_template_training(context, answer, question):
    """
    Creates a chat template for training the TinyLlama model.

    Args:
        question: The question to be answered.
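        context: The context information to be used for generating the answer.
        answer: The answer to be generated by the LLM.

    Returns:
        A string containing the chat template.
    """

    # Same prompt as sql_chat_template, followed by the target SQL query
    # and a closing <|im_end|> tag so the model learns where to stop.
    template = f"""\
    <|im_start|>user
    Given the following context, generate an SQL query for the following question.
    Just generate the query only and nothing else.
    Remember to only use the table columns in the context.
    context:{context}
    question:{question}
    <|im_end|>
    <|im_start|>assistant
    {answer}
    <|im_end|>
    """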
    # Remove any leading whitespace characters from each line in the template.
    template = "\n".join([line.lstrip() for line in template.splitlines()])
    return template
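With a plain `text` field like this, SFTTrainer by default computes the loss on essentially every token, including the instruction and the schema. A common alternative is completion-only training, where prompt positions get the label `-100` (the default `ignore_index` of PyTorch's cross-entropy loss) so that only the SQL answer is learned. The sketch below shows just the labeling idea on made-up token IDs; it is not wired into the trainer in this notebook. Hugging Face causal LM models shift labels internally, so the labels line up position by position with `input_ids`.

```python
# Label value that PyTorch's CrossEntropyLoss ignores by default (ignore_index=-100).
IGNORE_INDEX = -100

def completion_only_labels(prompt_ids, answer_ids):
    """Concatenate prompt and answer token IDs; mask the prompt positions in the labels."""
    input_ids = list(prompt_ids) + list(answer_ids)
    labels = [IGNORE_INDEX] * len(prompt_ids) + list(answer_ids)
    return input_ids, labels

# Toy token IDs standing in for a tokenized prompt and SQL answer (made up).
prompt_ids = [1, 523, 28766, 318]
answer_ids = [5033, 4719, 2]
input_ids, labels = completion_only_labels(prompt_ids, answer_ids)
print(input_ids)  # [1, 523, 28766, 318, 5033, 4719, 2]
print(labels)     # [-100, -100, -100, -100, 5033, 4719, 2]
```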

In [None]:
df["text"] = df.apply(lambda x: sql_chat_template_training(x["context"],
                                                           x["answer"],
                                                           x["question"]),
                      axis=1)

# Use the first 30,000 rows for training and convert them to a Hugging Face Dataset.
sql_training_data = Dataset.from_pandas(df.head(30000))

In [None]:
sql_training_data

In [None]:
df.head()

In [None]:
print(df.iloc[0]['text'])

In [None]:
model.config

In [None]:
# The key/value cache only speeds up generation; disable it during training
# (it is also incompatible with gradient checkpointing).
model.config.use_cache = False

## Setup PEFT LoRA Settings

In [None]:
model

In [None]:
from peft import LoraConfig

# Define the PEFT configuration.
peft_config = LoraConfig(
    # Rank of the low-rank update matrices A and B.
    r=8,
    # Scaling factor: the LoRA update is multiplied by lora_alpha / r.
    lora_alpha=16,
    # Dropout applied to the input of the LoRA layers during training.
    lora_dropout=0.05,
    # Do not train any bias terms.
    bias="none",
    # No target_modules is given, so PEFT uses its default modules for Llama models.
    # Task type for decoder-only (causal) language modeling.
    task_type="CAUSAL_LM"
)
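To get a feel for how small this adapter is, here is a back-of-the-envelope count of the trainable parameters LoRA adds, plus the `lora_alpha / r` scaling. It assumes PEFT's default Llama target modules (`q_proj` and `v_proj`) and TinyLlama-1.1B's published dimensions (hidden size 2048, 22 layers, 32 attention heads, 4 key/value heads); confirm these against `model.config` and, after wrapping, with `print_trainable_parameters()` on the PEFT model.

```python
def lora_param_count(in_features, out_features, r):
    """Trainable parameters LoRA adds to one linear layer: A is (r x in), B is (out x r)."""
    return r * in_features + out_features * r

# Assumed TinyLlama-1.1B dimensions; check these against model.config.
HIDDEN_SIZE = 2048
NUM_LAYERS = 22
NUM_HEADS = 32
NUM_KV_HEADS = 4
HEAD_DIM = HIDDEN_SIZE // NUM_HEADS  # 64
KV_DIM = NUM_KV_HEADS * HEAD_DIM     # 256 (grouped-query attention)

r, lora_alpha = 8, 16

# Assumed target modules: q_proj (2048 -> 2048) and v_proj (2048 -> 256) in every layer.
per_layer = lora_param_count(HIDDEN_SIZE, HIDDEN_SIZE, r) + lora_param_count(HIDDEN_SIZE, KV_DIM, r)
total = NUM_LAYERS * per_layer

print(f"LoRA parameters per layer: {per_layer:,}")  # 51,200
print(f"Total LoRA parameters: {total:,}")          # 1,126,400
print(f"Scaling lora_alpha / r: {lora_alpha / r}")  # 2.0
```

Under these assumptions the adapter is about 0.1% of the model's 1.1B parameters.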

In [None]:
sql_training_data

In [None]:
# Full batches in one epoch over 30,000 rows at batch size 32
30000 // 32

In [None]:
# Training rows processed in 500 steps at batch size 32
500 * 32

## Setup Supervised Fine-tuning Training Config Settings

In [None]:
from transformers import TrainingArguments

# Define the training arguments.
training_args = TrainingArguments(
    # Set the output directory for the training run.
    output_dir="tiny_sql_llama",
    # Set the per-device training batch size.
    per_device_train_batch_size=32,
    # Set the number of gradient accumulation steps.
    gradient_accumulation_steps=1,
    # Set the optimizer to use.
    optim="paged_adamw_32bit",
    # Set the learning rate.
    learning_rate=2e-4,
    # Set the learning rate scheduler type.
    lr_scheduler_type="cosine",
    # Save checkpoints every `save_steps` steps.
    save_strategy="steps",
    # Log the training loss every 50 steps and save a checkpoint every 100 steps.
    logging_steps=50,
    save_steps=100,
    # Set the maximum number of training steps.
    max_steps=500,
    # Enable fp16 training.
    fp16=True,
)
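With `lr_scheduler_type="cosine"`, `max_steps=500` and no warmup (`TrainingArguments` defaults to zero warmup steps), the learning rate decays from 2e-4 to zero along a half cosine. The function below is our own re-implementation of that curve, written to mirror the formula of the transformers cosine-with-warmup schedule as we understand it; it is a sketch for intuition, not the library code.

```python
import math

def cosine_lr(step, max_steps, base_lr, warmup_steps=0):
    """Learning rate at `step`: linear warmup, then half-cosine decay to 0."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
    return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

for step in (0, 125, 250, 375, 500):
    print(step, f"{cosine_lr(step, max_steps=500, base_lr=2e-4):.2e}")
```

Halfway through training (step 250) the learning rate has dropped to half its initial value.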

In [None]:
from trl import SFTTrainer

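# Initialize the SFTTrainer.
# Note: this call uses the older SFTTrainer signature. Newer TRL releases move
# dataset_text_field and the sequence-length setting into SFTConfig and rename
# the tokenizer argument to processing_class.
trainer = SFTTrainer(
    # Set the model to be trained.
    model=model,
    # Set the training dataset.
    train_dataset=sql_training_data,
    # Set the PEFT configuration.
    peft_config=peft_config,
    # Set the training arguments.
    args=training_args,
    # Set the tokenizer.
    tokenizer=tokenizer,
    # Set the name of the text field in the dataset.
    dataset_text_field="text",
    # Truncate training examples to at most 1024 tokens.
    max_seq_length=1024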
)

## Fine-tune LLM with Supervised Fine-tuning

Train for 500 steps at batch size 32, which covers roughly 16,000 of the 30,000 prepared rows. This takes about 7-8 minutes, depending on your GPU.

Ideally, train on as much data as possible so that the model sees diverse schemas and queries, especially ones that resemble your own database structure.
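The step arithmetic behind these numbers, as a small helper (a sketch assuming a single GPU; `training_coverage` is a hypothetical name):

```python
def training_coverage(max_steps, per_device_batch_size, grad_accum_steps, num_devices, num_rows):
    """Return (effective batch size, rows processed, fraction of an epoch)."""
    effective_batch = per_device_batch_size * grad_accum_steps * num_devices
    rows_processed = max_steps * effective_batch
    return effective_batch, rows_processed, rows_processed / num_rows

# Values from the TrainingArguments above, assuming a single GPU.
batch, rows, epochs = training_coverage(max_steps=500, per_device_batch_size=32,
                                        grad_accum_steps=1, num_devices=1, num_rows=30000)
print(batch, rows, round(epochs, 3))  # 32 16000 0.533
```

Since the run stops after about half an epoch, the model sees a shuffled subset of the 30,000 rows rather than all of them.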

In [None]:
trainer.train()

## Save LoRA Adapter

In [None]:
trainer.save_model('tinyllama-text2sql')

In [None]:
# Remove the intermediate checkpoints saved under output_dir.
!rm -rf tiny_sql_llama

## Merge Text2SQL LoRA Adapter with LLM

In [None]:
from transformers import AutoModelForCausalLM
from peft import PeftModel
import torch

# Define the model ID.
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# Load the base model in float16 (not quantized) so the LoRA weights can be merged into it.
base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",
)

# Load the trained LoRA adapter on top of the base model.
model_path = "./tinyllama-text2sql"
peft_model = PeftModel.from_pretrained(base_model,
                                       model_path,
                                       device_map="auto")

# Merge the LoRA weights into the base weights and return a plain transformers model.
merged_llm = peft_model.merge_and_unload()
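`merge_and_unload()` can drop the adapter because the LoRA update is linear: for an input x, W x + (alpha / r) B (A x) equals (W + (alpha / r) B A) x, so the update can be folded into the base weight once. The toy pure-Python check below illustrates this with made-up 2x2 weights and rank 1; dropout is inactive at inference, so it is omitted.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def matvec(m, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in m]

r, lora_alpha = 1, 2
scaling = lora_alpha / r  # 2.0

W = [[1.0, 0.0], [0.0, 1.0]]  # frozen base weight, shape (out=2, in=2)
A = [[0.5, -1.0]]             # LoRA down-projection, shape (r=1, in=2)
B = [[2.0], [1.0]]            # LoRA up-projection, shape (out=2, r=1)
x = [3.0, 1.0]

# Unmerged: base path plus the scaled low-rank path, as during training.
unmerged = [w + scaling * d for w, d in zip(matvec(W, x), matvec(B, matvec(A, x)))]

# Merged: fold scaling * (B @ A) into W once, then use a single matrix.
delta = matmul(B, A)
W_merged = [[w + scaling * d for w, d in zip(w_row, d_row)] for w_row, d_row in zip(W, delta)]
merged = matvec(W_merged, x)

print(unmerged, merged)  # [5.0, 2.0] [5.0, 2.0]
```

After merging, inference runs through a single weight matrix per layer, with no extra adapter computation.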

## Test Fine-tuned LLM

In [None]:
question = "How many tackles for the player with over 0 fumble recovries and 0 forced fumbles?"
context = "CREATE TABLE table_name_96 (total VARCHAR, fumble_rec VARCHAR, fumble_force VARCHAR)"
sql_prompt = sql_chat_template(question, context)
print(sql_prompt)

In [None]:
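# Encode the prompt.
inputs = tokenizer(sql_prompt, return_tensors="pt").to('cuda')
# Generate the output with the merged fine-tuned model, stopping at the end tag it learned.
output = merged_llm.generate(**inputs, max_new_tokens=200,
                             eos_token_id=tokenizer.eos_token_id,
                             tokenizer=tokenizer, stop_strings=["<|im_end|>"])
# Decode the output.
text = tokenizer.decode(output[0], skip_special_tokens=True)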
print(text)
# Human Answer:
# SELECT COUNT(total) FROM table_name_96 WHERE fumble_rec > 0 AND fumble_force = 0

In [None]:
# The last 10 rows of df lie outside the first 30,000 rows used for training.
for row in df.tail(10).itertuples():
  question = row.question
  context = row.context
  sql_prompt = sql_chat_template(question,context)

  # Encode the prompt.
  inputs = tokenizer(sql_prompt, return_tensors="pt").to('cuda')
  # Generate the output.
  output = merged_llm.generate(**inputs, max_new_tokens=200,
                          eos_token_id=tokenizer.eos_token_id,
                          tokenizer=tokenizer, stop_strings=["<|im_end|>"])
  # Decode the output.
  text = tokenizer.decode(output[0], skip_special_tokens=True)
  print('Question:')
  print(question)
  print('Context:')
  print(context)
  print('AI Answer:')
  print(text.split('<|im_start|>assistant\n')[1].split('<|im_end|>')[0].strip('\n'))
  print('Human Answer:')
  print(row.answer)
  print('-'*30)
  print()
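The loop above slices the decoded text with `split(...)[1]`, which raises an `IndexError` if the assistant tag is missing. A slightly more defensive extraction helper, plus a lenient exact-match check for scoring AI answers against human answers, might look like the sketch below. Both helpers are hypothetical. Lowercasing also lowercases string literals, and exact match misses equivalent queries written differently, so executing both queries against a database is a stricter test.

```python
import re

ASSISTANT_TAG = "<|im_start|>assistant\n"
END_TAG = "<|im_end|>"

def extract_sql(generated_text):
    """Return the text after the last assistant tag, cut at the first end tag (if any)."""
    _, sep, tail = generated_text.rpartition(ASSISTANT_TAG)
    if not sep:
        return ""  # the decoded text did not contain the assistant tag
    return tail.split(END_TAG, 1)[0].strip()

def normalize_sql(sql):
    """Collapse whitespace, drop a trailing semicolon, and lowercase for a lenient comparison."""
    sql = re.sub(r"\s+", " ", sql).strip().rstrip(";").strip()
    return sql.lower()

def exact_match(predicted, reference):
    """True if two queries are identical after normalize_sql."""
    return normalize_sql(predicted) == normalize_sql(reference)

decoded = ("<|im_start|>user\nquestion:...\n<|im_end|>\n"
           "<|im_start|>assistant\nSELECT COUNT(total) FROM table_name_96\n<|im_end|>")
predicted = extract_sql(decoded)
print(predicted)
print(exact_match(predicted, "select count(total)  from table_name_96;"))
```

Inside the evaluation loop, one could accumulate `exact_match(extract_sql(text), row.answer)` to get a rough accuracy over the held-out rows.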