# Instruct Fine-tune Tiny Llama for Text2SQL using Supervised Fine-tuning

The goal here is to take a labeled Text2SQL dataset in which each example contains a database schema (as `CREATE TABLE` statements), a question in natural language, and the corresponding SQL query.

We then use Supervised Fine-tuning (SFT) to train an LLM to generate the SQL query from a user question and its database schema, as depicted in the following workflow.

![](https://i.imgur.com/h8xFXON.png)

## Load up the Tiny Llama LLM

In [1]:
from transformers import AutoTokenizer

# Define the model to fine-tune
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# Load the tokenizer for the specified model.
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Set the padding token to be the same as the end of sentence token.
tokenizer.pad_token = tokenizer.eos_token

In [2]:
from transformers import BitsAndBytesConfig, AutoModelForCausalLM

# Define the quantization configuration for memory-efficient training.
bnb_config = BitsAndBytesConfig(
    # Load the model weights in 4-bit quantized format.
    load_in_4bit=True,
    # Specify the quantization type to use for 4-bit quantization.
    bnb_4bit_quant_type="nf4",
    # Specify the data type to use for computations during training.
    bnb_4bit_compute_dtype="float16",
    # Specify whether to use double quantization for 4-bit quantization.
    bnb_4bit_use_double_quant=True
)

# Load the model from the specified model ID and apply the quantization configuration.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto"
)

In [3]:
model.device

device(type='cuda', index=0)
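Why use 4-bit quantization for a model with only 1.1B parameters? A rough weight-memory estimate helps answer that. The sketch below counts parameters from the layer shapes in the model printout further down (hidden size 2048, 256-dim k/v projections, MLP size 5632, vocabulary 32000, 22 layers). It then compares 16-bit storage of all weights with 4-bit storage of the linear layers. It makes three assumptions: only the decoder `Linear` layers are quantized, the embeddings and an untied `lm_head` stay in 16-bit, and quantization constants, activations and optimizer state are ignored. Treat the numbers as ballpark figures.

```python
# Rough weight-memory estimate for TinyLlama-1.1B from its layer shapes.
# Assumptions (not measured):
# - only the decoder Linear layers are stored in 4 bits (0.5 byte/param);
# - embed_tokens, an untied lm_head of the same shape, and the RMSNorm
#   weights stay in 16-bit (2 bytes/param);
# - quantization constants and all other overheads are ignored.

hidden, kv_dim, inter, vocab, n_layers = 2048, 256, 5632, 32000, 22

# Linear layers in one decoder layer.
linear_per_layer = (
    hidden * hidden        # q_proj
    + hidden * kv_dim      # k_proj
    + hidden * kv_dim      # v_proj
    + hidden * hidden      # o_proj
    + 3 * hidden * inter   # gate_proj, up_proj, down_proj
)
linear_params = n_layers * linear_per_layer
embed_params = 2 * vocab * hidden             # embed_tokens + lm_head
norm_params = n_layers * 2 * hidden + hidden  # two norms per layer + final norm
total_params = linear_params + embed_params + norm_params

fp16_gb = total_params * 2 / 1e9
nf4_gb = (linear_params * 0.5 + (embed_params + norm_params) * 2) / 1e9

print(f"total params:         {total_params:,}")
print(f"all weights 16-bit:   ~{fp16_gb:.2f} GB")
print(f"linear layers 4-bit:  ~{nf4_gb:.2f} GB")
```

Under these assumptions, the weights shrink from about 2.2 GB to about 0.75 GB. This leaves more GPU memory for activations and for LoRA optimizer state during training.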

## Test the LLM with a simple prompt

This chat model has already been fine-tuned on public data, so let's try a simple prompt.

In [4]:
prompt_txt = "Explain Generative AI in 1 line"

messages = [
    {
        "role": "system",
        "content": "Act as a helpful assistant",
    },
    {"role": "user",
     "content": prompt_txt},
]

prompt = tokenizer.apply_chat_template(messages,
                                       tokenize=False,
                                       add_generation_prompt=True)
print(prompt)

<|system|>
Act as a helpful assistant</s>
<|user|>
Explain Generative AI in 1 line</s>
<|assistant|>



In [7]:
# Encode the prompt.
inputs = tokenizer(prompt, return_tensors="pt").to('cuda')

# Generate the output.
output = model.generate(**inputs, max_new_tokens=500,
                        eos_token_id=tokenizer.eos_token_id,
                        tokenizer=tokenizer, stop_strings=["</s>"])

# Decode the output.
text = tokenizer.decode(output[0], skip_special_tokens=True)

In [8]:
print(text)

<|system|>
Act as a helpful assistant 
<|user|>
Explain Generative AI in 1 line 
<|assistant|>
Generative AI is a type of AI that can generate new content based on pre-existing data. It is a rapidly growing field of research and development, with potential applications in fields such as language translation, music composition, and image generation.


## Create Instruction Template for Instruction Tuning LLM for Text2SQL

Here we create an instruction prompt that asks the LLM to generate the SQL query for a question, given the database schema as context. The same format is later used to build the training examples from the Text2SQL dataset.

In [9]:
def sql_chat_template(question, context):
    """
    Creates a chat template for the Llama model.

    Args:
        question: The question to be answered.
        context: The context information to be used for generating the answer.

    Returns:
        A string containing the chat template.
    """

    template = f"""\
    <|im_start|>user
    Given the following context, generate an SQL query for the following question.
    Just generate the query only and nothing else.
    Remember to only use the table columns in the context.
    context:{context}
    question:{question}
    <|im_end|>
    <|im_start|>assistant
    """
    # Remove any leading whitespace characters from each line in the template.
    template = "\n".join([line.lstrip() for line in template.splitlines()])
    return template

## Test Prompt with a Sample Data point on the LLM

In [10]:
question = "How many tackles for the player with over 0 fumble recovries and 0 forced fumbles?"
context = "CREATE TABLE table_name_96 (total VARCHAR, fumble_rec VARCHAR, fumble_force VARCHAR)"
sql_prompt = sql_chat_template(question,context)
print(sql_prompt)

<|im_start|>user
Given the following context, generate an SQL query for the following question.
Just generate the query only and nothing else.
Remember to only use the table columns in the context.
context:CREATE TABLE table_name_96 (total VARCHAR, fumble_rec VARCHAR, fumble_force VARCHAR)
question:How many tackles for the player with over 0 fumble recovries and 0 forced fumbles?
<|im_end|>
<|im_start|>assistant



In [11]:
# Encode the prompt.
inputs = tokenizer(sql_prompt, return_tensors="pt").to('cuda')
# Generate the output.
output = model.generate(**inputs, max_new_tokens=200,
                        eos_token_id=tokenizer.eos_token_id,
                        tokenizer=tokenizer, stop_strings=["</s>"])
# Decode the output.
text = tokenizer.decode(output[0], skip_special_tokens=True)
print(text)
# Human Answer:
# SELECT COUNT(total) FROM table_name_96 WHERE fumble_rec > 0 AND fumble_force = 0

<|im_start|>user
Given the following context, generate an SQL query for the following question.
Just generate the query only and nothing else.
Remember to only use the table columns in the context.
context:CREATE TABLE table_name_96 (total VARCHAR, fumble_rec VARCHAR, fumble_force VARCHAR)
question:How many tackles for the player with over 0 fumble recovries and 0 forced fumbles?
<|im_end|>
<|im_start|>assistant
SELECT COUNT(DISTINCT total) AS total_fumbles_recoveries, COUNT(DISTINCT fumble_rec) AS fumble_recoveries_for_player_with_over_0_fumbles_recovries, COUNT(DISTINCT fumble_force) AS fumble_forces_for_player_with_over_0_fumbles_recovries FROM table_name_96 WHERE total > 0 AND fumble_rec > 0 AND fumble_force > 0;


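The LLM does produce a SQL query, but it is completely wrong. It counts distinct values of all three columns and adds an unnecessary `total > 0` filter. It also uses `fumble_force > 0` where the question asks for 0 forced fumbles.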

## Load Text2SQL Dataset

In [12]:
from datasets import load_dataset, Dataset
# Download the dataset for fine-tuning
dataset_id = "b-mc2/sql-create-context"
data = load_dataset(dataset_id, split="train")

# convert dataset to dataframe for simplicity
df = data.to_pandas()

In [13]:
df.shape

(78577, 3)

In [14]:
df.head()

|   | answer | question | context |
|---|--------|----------|---------|
| 0 | SELECT COUNT(*) FROM head WHERE age > 56 | How many heads of the departments are older th... | CREATE TABLE head (age INTEGER) |
| 1 | SELECT name, born_state, age FROM head ORDER B... | List the name, born state and age of the heads... | CREATE TABLE head (name VARCHAR, born_state VA... |
| 2 | SELECT creation, name, budget_in_billions FROM... | List the creation year, name and budget of eac... | CREATE TABLE department (creation VARCHAR, nam... |
| 3 | SELECT MAX(budget_in_billions), MIN(budget_in_... | What are the maximum and minimum budget of the... | CREATE TABLE department (budget_in_billions IN... |
| 4 | SELECT AVG(num_employees) FROM department WHER... | What is the average number of employees of the... | CREATE TABLE department (num_employees INTEGER... |


Each data point (row) contains three fields:

- a user question,
- the database schema context,
- the ground-truth SQL query, which the LLM must learn to generate from the question and the context.

## Instruction Tuning Dataset Preparation

Here we modify the previous prompt to also include the answer. This way, the model learns to generate the answer from the prompt instructions and the other fields.

In [15]:
def sql_chat_template_training(context, answer, question):
    """
    Creates a chat template for training the TinyLlama model.

    Args:
        context: The database schema used as context for the question.
        answer: The ground-truth SQL query the LLM should learn to generate.
        question: The natural-language question to be answered.

    Returns:
        A string containing the chat template.
    """

    template = f"""\
    <|im_start|>user
    Given the following context, generate an SQL query for the following question.
    Just generate the query only and nothing else.
    Remember to only use the table columns in the context.
    context:{context}
    question:{question}
    <|im_end|>
    <|im_start|>assistant
    {answer}
    <|im_end|>
    """
    # Remove any leading whitespace characters from each line in the template.
    template = "\n".join([line.lstrip() for line in template.splitlines()])
    return template
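The inference prompt produced by `sql_chat_template` should be an exact prefix of the training text produced by `sql_chat_template_training`. That way, at inference time the model continues from exactly the point where the SQL answer started during training. The standalone check below verifies this property. It re-declares both functions with the same template text (docstrings dropped) so that it runs on its own.

```python
# Standalone copies of the two template functions above (same template text).

def sql_chat_template(question, context):
    template = f"""\
    <|im_start|>user
    Given the following context, generate an SQL query for the following question.
    Just generate the query only and nothing else.
    Remember to only use the table columns in the context.
    context:{context}
    question:{question}
    <|im_end|>
    <|im_start|>assistant
    """
    return "\n".join([line.lstrip() for line in template.splitlines()])


def sql_chat_template_training(context, answer, question):
    template = f"""\
    <|im_start|>user
    Given the following context, generate an SQL query for the following question.
    Just generate the query only and nothing else.
    Remember to only use the table columns in the context.
    context:{context}
    question:{question}
    <|im_end|>
    <|im_start|>assistant
    {answer}
    <|im_end|>
    """
    return "\n".join([line.lstrip() for line in template.splitlines()])


context = "CREATE TABLE head (age INTEGER)"
question = "How many heads of the departments are older than 56 ?"
answer = "SELECT COUNT(*) FROM head WHERE age > 56"

prompt = sql_chat_template(question, context)
train_text = sql_chat_template_training(context, answer, question)

# The inference prompt is an exact prefix of the training text...
print(train_text.startswith(prompt))  # True
# ...and what follows it is the answer plus the <|im_end|> marker.
print(repr(train_text[len(prompt):]))
```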

In [16]:
df["text"] = df.apply(lambda x: sql_chat_template_training(x["context"],
                                                           x["answer"],
                                                           x["question"]),
                      axis=1)

# Convert the first 30,000 rows of the dataframe back to a Dataset object for training.
sql_training_data = Dataset.from_pandas(df.head(30000))

In [17]:
sql_training_data

Dataset({
    features: ['answer', 'question', 'context', 'text'],
    num_rows: 30000
})

In [18]:
df.head()

|   | answer | question | context | text |
|---|--------|----------|---------|------|
| 0 | SELECT COUNT(*) FROM head WHERE age > 56 | How many heads of the departments are older th... | CREATE TABLE head (age INTEGER) | <\|im_start\|>user\nGiven the following context,... |
| 1 | SELECT name, born_state, age FROM head ORDER B... | List the name, born state and age of the heads... | CREATE TABLE head (name VARCHAR, born_state VA... | <\|im_start\|>user\nGiven the following context,... |
| 2 | SELECT creation, name, budget_in_billions FROM... | List the creation year, name and budget of eac... | CREATE TABLE department (creation VARCHAR, nam... | <\|im_start\|>user\nGiven the following context,... |
| 3 | SELECT MAX(budget_in_billions), MIN(budget_in_... | What are the maximum and minimum budget of the... | CREATE TABLE department (budget_in_billions IN... | <\|im_start\|>user\nGiven the following context,... |
| 4 | SELECT AVG(num_employees) FROM department WHER... | What is the average number of employees of the... | CREATE TABLE department (num_employees INTEGER... | <\|im_start\|>user\nGiven the following context,... |


In [19]:
print(df.iloc[0]['text'])

<|im_start|>user
Given the following context, generate an SQL query for the following question.
Just generate the query only and nothing else.
Remember to only use the table columns in the context.
context:CREATE TABLE head (age INTEGER)
question:How many heads of the departments are older than 56 ?
<|im_end|>
<|im_start|>assistant
SELECT COUNT(*) FROM head WHERE age > 56
<|im_end|>



In [20]:
model.config

LlamaConfig {
  "_attn_implementation_autoset": true,
  "_name_or_path": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "head_dim": 64,
  "hidden_act": "silu",
  "hidden_size": 2048,
  "initializer_range": 0.02,
  "intermediate_size": 5632,
  "max_position_embeddings": 2048,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 22,
  "num_key_value_heads": 4,
  "pretraining_tp": 1,
  "quantization_config": {
    "_load_in_4bit": true,
    "_load_in_8bit": false,
    "bnb_4bit_compute_dtype": "float16",
    "bnb_4bit_quant_storage": "uint8",
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_use_double_quant": true,
    "llm_int8_enable_fp32_cpu_offload": false,
    "llm_int8_has_fp16_weight": false,
    "llm_int8_skip_modules": null,
    "llm_int8_threshold": 6.0,
    "load_in_4bit": true,
    "load_in_8bit":

In [21]:
# Disable the KV cache: it is only useful for generation and just uses extra memory during training.
model.config.use_cache = False

## Set Up PEFT LoRA Settings

In [22]:
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 2048)
    (layers): ModuleList(
      (0-21): 22 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear4bit(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear4bit(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear4bit(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear4bit(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=2048, out_features=5632, bias=False)
          (up_proj): Linear4bit(in_features=2048, out_features=5632, bias=False)
          (down_proj): Linear4bit(in_features=5632, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), e

In [23]:
from peft import LoraConfig

# Define the PEFT (LoRA) configuration.
# No target_modules given: PEFT falls back to its default for Llama models (q_proj and v_proj).
peft_config = LoraConfig(
    # Rank of the low-rank update matrices A and B.
    r=8,
    # Scaling factor: the LoRA update is scaled by lora_alpha / r.
    lora_alpha=16,
    # Dropout applied to the input of the LoRA layers.
    lora_dropout=0.05,
    # Do not train any bias terms.
    bias="none",
    # Task type for a causal (decoder-only) language model.
    task_type="CAUSAL_LM"
)
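With `r=8`, each adapted weight matrix W of shape (out, in) gets two small trainable matrices: A of shape (r, in) and B of shape (out, r). Their product B A is scaled by `lora_alpha / r` = 16 / 8 = 2. The sketch below counts how few trainable parameters this adds, using the `q_proj`/`v_proj` shapes from the model printout above. It assumes that PEFT's default Llama target modules (`q_proj` and `v_proj`) apply because `target_modules` is not set. Once the trainer below exists, `trainer.model.print_trainable_parameters()` should report the actual count for the real model.

```python
# Back-of-the-envelope count of LoRA trainable parameters for this config.
# Assumption: LoRA is applied only to q_proj and v_proj in every layer.

r = 8          # LoRA rank from peft_config
n_layers = 22  # num_hidden_layers of TinyLlama

# (in_features, out_features) taken from the model printout.
target_shapes = {
    "q_proj": (2048, 2048),
    "v_proj": (2048, 256),
}

def lora_params(in_features, out_features, rank):
    # LoRA adds A: (rank x in_features) and B: (out_features x rank).
    return rank * in_features + out_features * rank

per_layer = sum(lora_params(i, o, r) for i, o in target_shapes.values())
trainable = n_layers * per_layer
frozen_qv = n_layers * sum(i * o for i, o in target_shapes.values())

print(f"LoRA params per layer: {per_layer:,}")
print(f"LoRA params total:     {trainable:,}")
print(f"frozen q/v weights:    {frozen_qv:,}")
```

About 1.1M trainable parameters stand against roughly 104M frozen weights in the adapted matrices alone. This is why LoRA fine-tuning fits comfortably next to the 4-bit base model.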

In [24]:
sql_training_data

Dataset({
    features: ['answer', 'question', 'context', 'text'],
    num_rows: 30000
})

In [25]:
30000 // 32

937

In [26]:
500 * 32

16000
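The two cells above compute how long the training run is. `30000 // 32` counts full batches only. The Trainer's dataloader keeps the final partial batch by default (`dataloader_drop_last=False`), so one epoch over 30,000 rows is 938 steps. With `max_steps=500`, the model therefore sees 16,000 examples, a bit more than half an epoch. The sketch below assumes a single GPU and no gradient accumulation, as configured in the next cell.

```python
import math

num_rows = 30_000  # rows in sql_training_data
batch_size = 32    # per_device_train_batch_size (one GPU, no gradient accumulation)
max_steps = 500    # max_steps in SFTConfig

full_batches = num_rows // batch_size               # 937, as computed above
steps_per_epoch = math.ceil(num_rows / batch_size)  # the last partial batch is a step too
samples_seen = max_steps * batch_size
epochs = samples_seen / num_rows

print(full_batches, steps_per_epoch, samples_seen, round(epochs, 3))
```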

## Set Up Supervised Fine-tuning Training Config Settings

In [27]:
from trl import SFTConfig

# Define the training arguments.
training_args = SFTConfig(
    # Set the output directory for the training run.
    output_dir="tiny_sql_llama",
    # Set the per-device training batch size.
    per_device_train_batch_size=32,
    # Set the number of gradient accumulation steps.
    gradient_accumulation_steps=1,
    # Set the optimizer to use.
    optim="paged_adamw_32bit",
    # Set the learning rate.
    learning_rate=2e-4,
    # Set the learning rate scheduler type.
    lr_scheduler_type="cosine",
    # Save checkpoints every `save_steps` steps.
    save_strategy="steps",
    # Log the training loss every 50 steps.
    logging_steps=50,
    # Save a checkpoint every 100 steps.
    save_steps=100,
    # Set the maximum number of training steps.
    max_steps=500,
    # Enable fp16 mixed-precision training.
    fp16=True,
    # Set the name of the text field in the dataset.
    dataset_text_field="text",
    # Maximum tokenized sequence length (named `max_length` in newer TRL releases).
    max_seq_length=1024
)
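With `lr_scheduler_type="cosine"` and no warmup configured, the learning rate should decay from 2e-4 at step 0 to 0 at step 500 along half a cosine curve. This assumes the TrainingArguments defaults of 0 for `warmup_steps` and `warmup_ratio` apply. The sketch below is a hand-written version of that schedule. It mirrors the shape of `transformers.get_cosine_schedule_with_warmup` but is not that function, and it shows roughly which learning rate each logged step trained with.

```python
import math

def cosine_lr(step, max_steps, peak_lr=2e-4, warmup_steps=0):
    """Learning rate at `step`: linear warmup, then cosine decay from peak_lr to 0."""
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

for step in (0, 125, 250, 375, 500):
    print(step, f"{cosine_lr(step, 500):.2e}")
```

The rate halves by the midpoint (1e-4 at step 250) and is already small over the last 100 steps. This fits the flattening of the training loss in the later steps.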

In [28]:
from trl import SFTTrainer

# Initialize the SFTTrainer.
trainer = SFTTrainer(
    # Set the model to be trained.
    model=model,
    # Set the training dataset.
    train_dataset=sql_training_data,
    # Set the PEFT configuration.
    peft_config=peft_config,
    # Set the training arguments.
    args=training_args,
    # Set the tokenizer.
    processing_class=tokenizer
)

In [29]:
trainer.train_dataset[0]

{'answer': 'SELECT COUNT(*) FROM head WHERE age > 56',
 'question': 'How many heads of the departments are older than 56 ?',
 'context': 'CREATE TABLE head (age INTEGER)',
 'text': '<|im_start|>user\nGiven the following context, generate an SQL query for the following question.\nJust generate the query only and nothing else.\nRemember to only use the table columns in the context.\ncontext:CREATE TABLE head (age INTEGER)\nquestion:How many heads of the departments are older than 56 ?\n<|im_end|>\n<|im_start|>assistant\nSELECT COUNT(*) FROM head WHERE age > 56\n<|im_end|>\n',
 'input_ids': [1, 529, 29989, 326, 29918, 2962, 29989, 29958, 1792, 13,
  29954, 5428, 278, 1494, 3030, 29892, 5706, 385, 3758, 2346,
  363, 278, 1494, 1139, 29889, 13, 14084, 5706, 278, 2346,
  871, 322, 3078, 1683, 29889, 13, 7301, 1096, 304, 871,
  671, 278, 1591, 4341, 297, 278, 3030, 29889, 13, 4703,
  29901, 27045, 10

In [30]:
tokenizer.decode(trainer.train_dataset[0]['input_ids'])

'<s> <|im_start|>user\nGiven the following context, generate an SQL query for the following question.\nJust generate the query only and nothing else.\nRemember to only use the table columns in the context.\ncontext:CREATE TABLE head (age INTEGER)\nquestion:How many heads of the departments are older than 56 ?\n<|im_end|>\n<|im_start|>assistant\nSELECT COUNT(*) FROM head WHERE age > 56\n<|im_end|>\n'

## Fine-tune LLM with Supervised Fine-tuning

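We train for 500 steps with a batch size of 32, i.e. 500 × 32 = 16,000 examples, roughly half of the 30,000-row training subset. On this GPU the run took about 7-8 minutes (`train_runtime` ≈ 463 s in the output below).

Ideally, train on as much data as possible. This helps the model cover diverse schemas and queries based on your own database structure.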

In [31]:
trainer.train()

| Step | Training Loss |
|------|---------------|
| 50 | 1.3389 |
| 100 | 0.5827 |
| 150 | 0.5443 |
| 200 | 0.5177 |
| 250 | 0.5094 |
| 300 | 0.5028 |
| 350 | 0.4955 |
| 400 | 0.4945 |
| 450 | 0.4927 |
| 500 | 0.4875 |


TrainOutput(global_step=500, training_loss=0.5965973892211914, metrics={'train_runtime': 462.9946, 'train_samples_per_second': 34.558, 'train_steps_per_second': 1.08, 'total_flos': 2.148912365292749e+16, 'train_loss': 0.5965973892211914})

## Save LoRA Adapter

In [32]:
trainer.save_model('tinyllama-text2sql')

In [33]:
# remove checkpoints
!rm -rf tiny_sql_llama

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


## Merge Text2SQL LoRA Adapter with LLM

In [34]:
from peft import PeftModel
from transformers import AutoModelForCausalLM
import torch

# Define the model ID.
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# Load the base model in float16 rather than 4-bit, so the LoRA weights
# can be merged into regular 16-bit weights.
base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto"
)

# Attach the trained LoRA adapter saved by trainer.save_model().
model_path = "./tinyllama-text2sql"
peft_model = PeftModel.from_pretrained(base_model, model_path)

# Fold the LoRA weights into the base weights; this returns a plain transformers model.
merged_llm = peft_model.merge_and_unload()
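`merge_and_unload()` folds each LoRA update into its base weight, W' = W + (lora_alpha / r) B A. The merged model then runs as a plain transformers model without the extra adapter matrix multiplications. The toy NumPy check below uses random matrices with made-up small sizes, not the real TinyLlama weights. It illustrates why merging is exact, up to floating-point error, for a linear layer. LoRA dropout plays no role here because it is only active in training mode.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, r, alpha = 16, 8, 4, 8  # toy sizes, not TinyLlama's
scaling = alpha / r

W = rng.normal(size=(d_out, d_in))  # frozen base weight
A = rng.normal(size=(r, d_in))      # LoRA down-projection
B = rng.normal(size=(d_out, r))     # LoRA up-projection (random here; zero-initialized in real LoRA)
x = rng.normal(size=(d_in,))

# Unmerged: base output plus the scaled low-rank update, as during training.
y_adapter = W @ x + scaling * (B @ (A @ x))

# Merged: fold the update into a single weight matrix once.
W_merged = W + scaling * (B @ A)
y_merged = W_merged @ x

print(np.allclose(y_adapter, y_merged))  # True
```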

## Test Fine-tuned LLM

In [50]:
question = "How many tackles for the player with over 0 fumble recovries and 0 forced fumbles?"
context = "CREATE TABLE table_name_96 (total VARCHAR, fumble_rec VARCHAR, fumble_force VARCHAR)"
sql_prompt = sql_chat_template(question,context)
print(sql_prompt)

<|im_start|>user
Given the following context, generate an SQL query for the following question.
Just generate the query only and nothing else.
Remember to only use the table columns in the context.
context:CREATE TABLE table_name_96 (total VARCHAR, fumble_rec VARCHAR, fumble_force VARCHAR)
question:How many tackles for the player with over 0 fumble recovries and 0 forced fumbles?
<|im_end|>
<|im_start|>assistant



In [51]:
# Encode the prompt.
inputs = tokenizer(sql_prompt, return_tensors="pt").to('cuda')
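# Generate the output. `model` still holds the trained LoRA adapter, which
# SFTTrainer injected into it in place; the merged model is used in the next cell.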
output = model.generate(**inputs, max_new_tokens=200,
                        eos_token_id=tokenizer.eos_token_id,
                        tokenizer=tokenizer, stop_strings=["<|im_end|>"])
# Decode the output.
text = tokenizer.decode(output[0], skip_special_tokens=True)
print(text)
# Human Answer:
# SELECT COUNT(total) FROM table_name_96 WHERE fumble_rec > 0 AND fumble_force = 0

<|im_start|>user
Given the following context, generate an SQL query for the following question.
Just generate the query only and nothing else.
Remember to only use the table columns in the context.
context:CREATE TABLE table_name_96 (total VARCHAR, fumble_rec VARCHAR, fumble_force VARCHAR)
question:How many tackles for the player with over 0 fumble recovries and 0 forced fumbles?
<|im_end|>
<|im_start|>assistant
SELECT COUNT(total) FROM table_name_96 WHERE fumble_rec > 0 AND fumble_force = 0
<|im_end|>
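Comparing the AI answer with the human answer by eye works for one example. For many rows, you could score them automatically. The sketch below uses three hypothetical helpers that are not part of the notebook: `extract_answer`, `normalize_sql` and `exact_match`. They pull the SQL out of the decoded text and compare it with the reference after light normalization. Exact match is strict: two semantically equivalent queries written differently count as a miss. Lowercasing also makes string-literal comparisons case-insensitive. Treat the result as a rough signal that complements an LLM-based judge.

```python
import re

def extract_answer(generated_text):
    """Return the text after the assistant marker, cut at <|im_end|> if present."""
    marker = "<|im_start|>assistant\n"
    if marker not in generated_text:
        return ""
    answer = generated_text.split(marker, 1)[1]
    return answer.split("<|im_end|>", 1)[0].strip()

def normalize_sql(query):
    """Collapse whitespace, drop a trailing semicolon and lowercase (hypothetical normalization)."""
    query = re.sub(r"\s+", " ", query.strip()).rstrip(";").strip()
    return query.lower()

def exact_match(ai_query, human_query):
    return normalize_sql(ai_query) == normalize_sql(human_query)

# Usage on a decoded generation shaped like the outputs in this notebook.
generated = (
    "<|im_start|>user\n...\n<|im_end|>\n"
    "<|im_start|>assistant\n"
    'SELECT score FROM table_name_66 WHERE record = "48-55"\n<|im_end|>'
)
ai = extract_answer(generated)
print(ai)
print(exact_match(ai, 'SELECT score FROM table_name_66 WHERE record = "48-55"'))  # True
```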


In [52]:
for row in df.tail(20).itertuples():
  question = row.question
  context = row.context
  sql_prompt = sql_chat_template(question,context)

  # Encode the prompt.
  inputs = tokenizer(sql_prompt, return_tensors="pt").to('cuda')
  # Generate the output.
  output = merged_llm.generate(**inputs, max_new_tokens=200,
                          eos_token_id=tokenizer.eos_token_id,
                          tokenizer=tokenizer, stop_strings=["<|im_end|>"])
  # Decode the output.
  text = tokenizer.decode(output[0], skip_special_tokens=True)
  print('Question:')
  print(question)
  print('Context:')
  print(context)
  print('AI Answer:')
  print(text.split('<|im_start|>assistant\n')[1].split('<|im_end|>')[0].strip('\n'))
  print('Human Answer:')
  print(row.answer)
  print('-'*30)
  print()

Question:
Which score has a Record of 48-55?
Context:
CREATE TABLE table_name_66 (score VARCHAR, record VARCHAR)
AI Answer:
SELECT score FROM table_name_66 WHERE record = "48-55"
Human Answer:
SELECT score FROM table_name_66 WHERE record = "48-55"
------------------------------

Question:
Which score has a Record of 44-49?
Context:
CREATE TABLE table_name_96 (score VARCHAR, record VARCHAR)
AI Answer:
SELECT score FROM table_name_96 WHERE record = "44-49"
Human Answer:
SELECT score FROM table_name_96 WHERE record = "44-49"
------------------------------

Question:
Which score has an Opponent of white sox, and a Record of 2-0?
Context:
CREATE TABLE table_name_92 (score VARCHAR, opponent VARCHAR, record VARCHAR)
AI Answer:
SELECT score FROM table_name_92 WHERE opponent = "white sox" AND record = "2-0"
Human Answer:
SELECT score FROM table_name_92 WHERE opponent = "white sox" AND record = "2-0"
------------------------------

Question:
How many votes did candice sjostrom receive?
Context:


## Enter API Tokens

In [38]:
from getpass import getpass

# OPENAI_KEY = getpass('Enter your OpenAI Key: ')
GROQ_API_KEY = getpass('Enter your Groq API Key: ')

Enter your Groq API Key:  ········


## Set Up Environment Variables

In [40]:
import os

# os.environ['OPENAI_API_KEY'] = OPENAI_KEY
os.environ['GROQ_API_KEY'] = GROQ_API_KEY

In [55]:
from langchain_groq import ChatGroq

# Alternative model in case the quota is exhausted: llama3-70b-8192
llm = ChatGroq(model_name="llama-3.2-90b-vision-preview",
               temperature=0)

In [57]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

prompt = """Act as an expert grader for SQL queries.
Your task is to verify the correctness of a SQL query generated by a language model
by comparing it against the human golden reference query given below.
You will also be given table schemas and the natural language query for which the SQL query was generated.

Grade each query with either PASS or FAIL. If you grade it as fail, generate a Reason
on why you graded it as fail. 

Natural Language Query: {nl_query}
Schema: {context}

Model Query: {ai_query}
Human Golden Reference Query: {human_query}

Do not explain each step of the grading. 
Response should be as follows:
Result: PASS or FAIL
Reason: <reason if fail>
"""

prompt_template = ChatPromptTemplate.from_template(prompt)
llm_as_judge = prompt_template | llm | StrOutputParser()
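The judge returns free text in the `Result: ...` / `Reason: ...` format requested by the prompt. To aggregate grades over many rows, you could parse the `Result:` line. The sketch below assumes the judge follows the requested format, which an LLM may not always do, and its helper names (`parse_grade`, `pass_rate`) are hypothetical. In the evaluation loop, each string returned by `llm_as_judge.invoke(...)` could be collected into a list and passed to `pass_rate`.

```python
import re

def parse_grade(judge_output):
    """Return True for PASS, False for FAIL, None if no 'Result:' line is found."""
    match = re.search(r"Result:\s*(PASS|FAIL)", judge_output, flags=re.IGNORECASE)
    if match is None:
        return None
    return match.group(1).upper() == "PASS"

def pass_rate(judge_outputs):
    """Fraction of parseable grades that are PASS (unparseable outputs are skipped)."""
    grades = [g for g in map(parse_grade, judge_outputs) if g is not None]
    return sum(grades) / len(grades) if grades else 0.0

outputs = ["Result: PASS\nReason:", "Result: FAIL\nReason: wrong column"]
print(pass_rate(outputs))  # 0.5
```

Skipping unparseable outputs keeps the rate honest about what was graded. You may also want to count and inspect those outputs separately.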

In [58]:
prompt_template

ChatPromptTemplate(input_variables=['ai_query', 'context', 'human_query', 'nl_query'], input_types={}, partial_variables={}, messages=[HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['ai_query', 'context', 'human_query', 'nl_query'], input_types={}, partial_variables={}, template='Act as an expert grader for SQL queries.\nYour task is to verify the correctness of a SQL query generated by a language model\nby comparing it against the human golden reference query given below.\nYou will also be given table schemas and the natural language query for which the SQL query was generated.\n\nGrade each query with either PASS or FAIL. If you grade it as fail, generate a Reason\non why you graded it as fail. \n\nNatural Language Query: {nl_query}\nSchema: {context}\n\nModel Query: {ai_query}\nHuman Golden Reference Query: {human_query}\n\nDo not explain each step of the grading. \nResponse should be as follows:\nResult: PASS or FAIL\nReason: <reason if fail>\n'), additional_kwar

In [59]:
for row in df.tail(20).itertuples():
  question = row.question
  context = row.context
  sql_prompt = sql_chat_template(question,context)

  # Encode the prompt.
  inputs = tokenizer(sql_prompt, return_tensors="pt").to('cuda')
  # Generate the output.
  output = merged_llm.generate(**inputs, max_new_tokens=200,
                          eos_token_id=tokenizer.eos_token_id,
                          tokenizer=tokenizer, stop_strings=["<|im_end|>"])
  # Decode the output.
  text = tokenizer.decode(output[0], skip_special_tokens=True)
  ai_answer = text.split('<|im_start|>assistant\n')[1].split('<|im_end|>')[0].strip('\n')
  human_answer = row.answer
  print('Question:')
  print(question)
  print('Context:')
  print(context)
  print('AI Answer:')
  print(ai_answer)
  print('Human Answer:')
  print(human_answer)
  print('LLM as Judge Grade:')
  print(llm_as_judge.invoke({
      'ai_query': ai_answer, 
      'context': context, 
      'human_query': human_answer, 
      'nl_query': question
  }))
  print()
  print('-'*30)
  print()

Question:
Which score has a Record of 48-55?
Context:
CREATE TABLE table_name_66 (score VARCHAR, record VARCHAR)
AI Answer:
SELECT score FROM table_name_66 WHERE record = "48-55"
Human Answer:
SELECT score FROM table_name_66 WHERE record = "48-55"
LLM as Judge Grade:
Result: PASS
Reason:

------------------------------

Question:
Which score has a Record of 44-49?
Context:
CREATE TABLE table_name_96 (score VARCHAR, record VARCHAR)
AI Answer:
SELECT score FROM table_name_96 WHERE record = "44-49"
Human Answer:
SELECT score FROM table_name_96 WHERE record = "44-49"
LLM as Judge Grade:
Result: PASS
Reason:

------------------------------

Question:
Which score has an Opponent of white sox, and a Record of 2-0?
Context:
CREATE TABLE table_name_92 (score VARCHAR, opponent VARCHAR, record VARCHAR)
AI Answer:
SELECT score FROM table_name_92 WHERE opponent = "white sox" AND record = "2-0"
Human Answer:
SELECT score FROM table_name_92 WHERE opponent = "white sox" AND record = "2-0"
LLM as Judge