# Natural Language to SQL using Google's Gemma

[**Bhavesh Bhatt - Link to my YouTube Channel**](https://www.youtube.com/BhaveshBhatt8791?sub_confirmation=1)

Click the link below to open this notebook in Colab, where you can save and run your own copy.

<a href="https://colab.research.google.com/github/bhattbhavesh91/google-gemma-finetuning-n2sql/blob/main/n2sql-google-gemma-finetuning-notebook.ipynb" target="_blank"><img height="40" alt="Run your own notebook in Colab" src = "https://colab.research.google.com/assets/colab-badge.svg"></a>

In [1]:
!pip3 install -q -U bitsandbytes==0.42.0
# %%capture
!pip install transformers datasets accelerate peft huggingface_hub hf_transfer flash-attn trl wandb -qU


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip available: [0m[31;49m22.2.2[0m[39;49m -> [0m[32;49m24.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip available: [0m[31;49m22.2.2[0m[39;49m -> [0m[32;49m24.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [2]:
import os

import torch
import transformers
from datasets import load_dataset
from IPython.display import Markdown, display
from peft import LoraConfig, get_peft_model
from transformers import (AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig,
                          GemmaTokenizer, TrainingArguments)
from trl import SFTTrainer

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Is bfloat16 available?: {torch.cuda.is_bf16_supported()}")

Is bfloat16 available?: True


In [4]:
from huggingface_hub import login

login(
  token=os.environ["HF_TOKEN"],  # ADD YOUR TOKEN HERE via the HF_TOKEN environment variable; never hard-code a real token
)

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /opt/app-root/src/.cache/huggingface/token
Login successful


In [5]:
# model_id = "google/gemma-2b"
model_id = "google/gemma-7b"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)
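
The commented-out `quantization_config=bnb_config` in the next cell would load the weights in 4-bit NF4 instead of bfloat16. As a rough sketch of what that saves, the parameter count can be estimated from the bfloat16 checkpoint size (the four shards pushed in cell [25] are listed as 5.00G, 4.98G, 4.98G and 2.11G, assumed here to be decimal gigabytes) and converted to a 4-bit footprint. This ignores quantization constants, activations, the KV cache and any layers bitsandbytes keeps in higher precision, so treat the results as ballpark figures only.

```python
# Rough weight-memory estimate for bf16 vs 4-bit NF4 loading (weights only).
SHARD_SIZES_GB = [5.00, 4.98, 4.98, 2.11]  # bf16 shard sizes from the cell [25] upload log (assumed decimal GB)

def estimate_params(checkpoint_gb, bytes_per_param=2.0):
    """Approximate parameter count from a checkpoint size (bf16 = 2 bytes per parameter)."""
    return checkpoint_gb * 1e9 / bytes_per_param

def weight_memory_gb(num_params, bits_per_param):
    """Memory for the weights alone, ignoring quantization constants and activations."""
    return num_params * bits_per_param / 8 / 1e9

params = estimate_params(sum(SHARD_SIZES_GB))
print(f"~{params / 1e9:.2f}B parameters")
print(f"bf16 weights: ~{weight_memory_gb(params, 16):.1f} GB")
print(f"nf4 weights:  ~{weight_memory_gb(params, 4):.1f} GB")
```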

In [6]:
model = AutoModelForCausalLM.from_pretrained(model_id,
                                             # quantization_config=bnb_config,
                                             device_map="auto",
                                             torch_dtype=torch.bfloat16,
                                             # device_map="cuda:0",
                                             attn_implementation="flash_attention_2"
                                    )
tokenizer = AutoTokenizer.from_pretrained(model_id)



`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.



In [7]:
print(f"Vocabulary size of Gemma7B: {len(tokenizer.get_vocab()):,}")

Vocabulary size of Gemma7B: 256,000


In [8]:
generation_config = {
    "max_new_tokens": 100,
    "do_sample": True,
    "temperature": 1,
    "top_k": 100,
    "top_p":0.90,
}
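
`generation_config` turns on sampling (`do_sample=True`) and shapes it with `temperature`, `top_k` and `top_p`. The sketch below is a plain-Python illustration of what those three settings do to a single next-token distribution, applied in the order temperature, then top-k, then top-p. It is not the transformers implementation, and `filter_logits` / `sample_next_token` are hypothetical helper names.

```python
import math
import random

def softmax(values):
    # Subtract the max before exponentiating for numerical stability.
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

def filter_logits(logits, temperature=1.0, top_k=100, top_p=0.90):
    """Return {token_id: probability} after temperature, top-k and top-p filtering."""
    # 1) Temperature: values < 1 sharpen the distribution, values > 1 flatten it.
    probs = softmax([l / temperature for l in logits])
    # 2) Top-k: keep the k most likely token ids and renormalise over them.
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]
    mass = sum(probs[i] for i in ranked)
    ranked_probs = [(i, probs[i] / mass) for i in ranked]
    # 3) Top-p (nucleus): keep the smallest prefix whose cumulative probability reaches top_p.
    kept, cumulative = [], 0.0
    for i, p in ranked_probs:
        kept.append((i, p))
        cumulative += p
        if cumulative >= top_p:
            break
    # Renormalise the surviving tokens so they sum to 1.
    total = sum(p for _, p in kept)
    return {i: p / total for i, p in kept}

def sample_next_token(logits, rng, **kwargs):
    """Draw one token id from the filtered distribution."""
    dist = filter_logits(logits, **kwargs)
    ids = list(dist)
    return rng.choices(ids, weights=[dist[i] for i in ids], k=1)[0]

toy_logits = [math.log(p) for p in (0.5, 0.3, 0.15, 0.05)]  # a 4-token toy vocabulary
print(filter_logits(toy_logits))  # top_p=0.90 drops the 0.05 tail token
print(sample_next_token(toy_logits, random.Random(0)))
```

With a real 256,000-token vocabulary, `top_k=100` does most of the pruning and `top_p=0.90` then trims the long tail inside those 100 candidates.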

In [9]:
input_text = "Write me a poem about Machine Learning."
input_ids = tokenizer(text=input_text, return_tensors="pt").to(device)
outputs = model.generate(**input_ids, **generation_config)
Markdown(tokenizer.decode(token_ids=outputs[0], skip_special_tokens=True))

Write me a poem about Machine Learning.

When we were kids

We looked at the sky

And wondered

How far it went

We saw the night’s black emptiness

And wondered

How much there was to see

But there were no lights

To show us

What was there to see

Today

We see the stars

We see the wonders of the universe

We see the galaxies

We see the black holes

We see the planets

We see the asteroids

We see the comets

We see the mete

In [10]:
text = "Quote: Our doubts are traitors,"
device = "cuda:0"
inputs = tokenizer(text, return_tensors="pt").to(device)

outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Quote: Our doubts are traitors, and make us lose the good we oft might win, by fearing to attempt.

This quote is


In [11]:
template = "You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables.\n\n" + \
"You must output the SQL query that answers the question.\n\n" + \
"### Input:\n" + \
"```{question}```\n\n" + \
"### Context:\n" + \
"```{context}```\n\n"
# "### Response:\n" + \
# "```{response}```"
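
The `template` above leaves the `### Response:` section commented out, and the `formatting_func` in cell [14] trains on a different `Question:/Context:/Answer:` layout than the prompt used for evaluation in cell [22]. One option, sketched here as an assumption rather than what this notebook does, is to build both the training text and the inference prompt from a single hypothetical helper, `build_prompt`, so the model sees the same layout in both places and learns to continue after `### Response:`.

```python
# A copy of the notebook's template with the response cue enabled (assumed layout).
TEMPLATE = (
    "You are a powerful text-to-SQL model. Your job is to answer questions about a database. "
    "You are given a question and context regarding one or more tables.\n\n"
    "You must output the SQL query that answers the question.\n\n"
    "### Input:\n```{question}```\n\n"
    "### Context:\n```{context}```\n\n"
    "### Response:\n"
)

def build_prompt(question, context, answer=None):
    """Inference prompt when answer is None; full training text (prompt + answer) otherwise."""
    prompt = TEMPLATE.format(question=question, context=context)
    return prompt if answer is None else prompt + answer

# Example row taken from the validation outputs further down.
example = {
    "question": "What is the Branding of the Frequency owned by Sound of Faith Broadcasting Group?",
    "context": "CREATE TABLE table_name_23 (branding VARCHAR, owner VARCHAR)",
    "answer": 'SELECT branding FROM table_name_23 WHERE owner = "sound of faith broadcasting group"',
}
print(build_prompt(example["question"], example["context"]))
```

When training on these strings you would typically also append `tokenizer.eos_token` so the model learns where an answer ends; that detail is likewise an assumption, not part of the original notebook.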

In [12]:
# lora_config = LoraConfig(
#     r = 8,
#     target_modules = ["q_proj", "o_proj", "k_proj", "v_proj",
#                       "gate_proj", "up_proj", "down_proj"],
#     task_type = "CAUSAL_LM",
# )

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

peft_model = get_peft_model(model=model, peft_config=lora_config)
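
With `r=16` and `lora_alpha=32`, each targeted linear layer of shape `d_in → d_out` gets two low-rank matrices, A (`r × d_in`) and B (`d_out × r`), i.e. `r · (d_in + d_out)` extra trainable parameters, and the adapter update is scaled by `lora_alpha / r = 2`. The sketch below counts those parameters for the seven `target_modules`. The Gemma-7B dimensions it uses (hidden size 3072, 16 attention heads of size 256, MLP size 24576, 28 layers) are assumptions to check against `model.config`; `peft_model.print_trainable_parameters()` reports the actual figure.

```python
# Count the trainable parameters LoRA adds for the target_modules in cell [12].
# The Gemma-7B dimensions below are assumptions; compare them with model.config.
HIDDEN, HEADS, HEAD_DIM, MLP, LAYERS = 3072, 16, 256, 24576, 28
R, ALPHA = 16, 32

def lora_params(d_in, d_out, r):
    """A is r x d_in and B is d_out x r, so LoRA adds r * (d_in + d_out) weights."""
    return r * (d_in + d_out)

attn = HEADS * HEAD_DIM  # output width of the q/k/v projections
shapes = {  # (d_in, d_out) per targeted module
    "q_proj": (HIDDEN, attn), "k_proj": (HIDDEN, attn), "v_proj": (HIDDEN, attn),
    "o_proj": (attn, HIDDEN),
    "gate_proj": (HIDDEN, MLP), "up_proj": (HIDDEN, MLP), "down_proj": (MLP, HIDDEN),
}
per_layer = sum(lora_params(d_in, d_out, R) for d_in, d_out in shapes.values())
total = per_layer * LAYERS
print(f"scaling alpha/r = {ALPHA / R}")
print(f"LoRA parameters per layer: {per_layer:,}; total: {total:,}")
print(f"bf16 adapter size: ~{total * 2 / 1e6:.0f} MB")
```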

In [13]:
# data = load_dataset("b-mc2/sql-create-context")
data = load_dataset("b-mc2/sql-create-context", split="train")

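# Caution: this adds `input_ids`/`attention_mask` built from question + context only (no answer).
# Some TRL versions treat a dataset that already contains `input_ids` as pre-tokenized and then
# ignore `formatting_func`, in which case the SQL answers never reach the training loss.
data = data.map(lambda samples: tokenizer(samples["question"],
                                          samples["context"]), batched=True)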

train_test_split = data.train_test_split(test_size=100, seed=1399, shuffle=True)
train_data = train_test_split["train"].shuffle()
val_data = train_test_split["test"].shuffle()
print(len(train_data), len(val_data))

78477 100
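
Passing an integer `test_size=100` to `train_test_split` asks for exactly 100 held-out rows, which is why 78,477 of the 78,577 rows remain for training. The plain-Python sketch below mimics a seeded shuffle-then-split to show the idea; it is not how the Hugging Face `datasets` library implements the split internally, so the actual indices will differ. Note also that the follow-up `.shuffle()` calls in cell [13] pass no seed, so row order changes between runs.

```python
import random

def seeded_split(num_rows, test_size, seed):
    """Shuffle row indices with a fixed seed, then carve off test_size rows (an int = absolute count)."""
    indices = list(range(num_rows))
    random.Random(seed).shuffle(indices)
    return indices[test_size:], indices[:test_size]

train_idx, test_idx = seeded_split(78_577, test_size=100, seed=1399)
print(len(train_idx), len(test_idx))
```

Because the seed is fixed, calling `seeded_split` again with the same arguments yields the same partition.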


In [14]:
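def formatting_func(example):
    # TRL may pass a batch (dict of lists) or a single row (dict of strings); emit one string per example.
    cols = [example[k] if isinstance(example[k], list) else [example[k]] for k in ("question", "context", "answer")]
    return [f"Question: {q}\nContext: {c}\nAnswer: {a}" for q, c, a in zip(*cols)]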

In [15]:
args_definition = dict(
    output_dir="./gemma7bit-lora-sql",
    overwrite_output_dir=True,
    evaluation_strategy="steps",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    learning_rate=3e-4,
    max_steps=500,
    lr_scheduler_type="cosine",
    max_grad_norm = 0.3,
    warmup_steps=2,
    logging_steps=2,
    save_steps=2,
    logging_first_step=True,
    seed=1399,
    bf16=True,
    report_to="wandb",
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    load_best_model_at_end=True
)
args = TrainingArguments(**args_definition)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
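
The warning above is emitted by the Rust-based `tokenizers` library when a process forks after tokenizers has already run in parallel (here, plausibly the batched `data.map(...)` tokenization in cell [13]). As the message itself suggests, setting `TOKENIZERS_PARALLELISM` explicitly silences it. A minimal sketch, assuming this line runs at the top of the notebook before any tokenizer is used:

```python
import os

# "false" disables the tokenizers thread pool, which avoids the fork warning; "true" keeps it.
# This must be set before the first tokenizer call to take effect.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
```

Choosing `"false"` trades some tokenization speed for fork safety.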


In [16]:
trainer = SFTTrainer(
    model=model,
    # train_dataset=data["train"],
    train_dataset=train_data,
    eval_dataset=val_data,
    args=args,
    # args=transformers.TrainingArguments(
    #     per_device_train_batch_size=2,
    #     gradient_accumulation_steps=2,
    #     warmup_steps=2,
    #     max_steps=75,
    #     learning_rate=2e-4,
    #     fp16=True,
    #     logging_steps=1,
    #     output_dir="outputs",
    #     optim="paged_adamw_8bit"
    # ),
    peft_config=lora_config,
    formatting_func=formatting_func,
)


# trainer = SFTTrainer(
#     model=peft_model,
#     args=args,
#     train_dataset=train_data,
#     eval_dataset=val_data,
#     tokenizer=tokenizer,
#     peft_config=peft_config,
#     formatting_func=formatting_func,
#     max_seq_length=1024,
#     packing=True,
#     callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]
# )


In [17]:
trainer.train()

[34m[1mwandb[0m: Currently logged in as: [33mliuxiangwin[0m ([33mliuxiangwin-free[0m). Use [1m`wandb login --relogin`[0m to force relogin


| Step | Training Loss | Validation Loss |
|---:|---:|---:|
| 2 | 47.4397 | 112.096146 |
| 4 | 54.9563 | 113.032021 |
| 6 | 43.0701 | 105.088303 |
| 8 | 29.3374 | 93.756378 |
| 10 | 24.013 | 70.502617 |
| 12 | 5.7244 | 70.364449 |
| 14 | 6.7112 | 69.091835 |
| 16 | 5.139 | 67.759377 |
| 18 | 5.658 | 64.892502 |
| 20 | 3.348 | 62.908638 |


TrainOutput(global_step=500, training_loss=3.085471360683441, metrics={'train_runtime': 2025.1541, 'train_samples_per_second': 0.247, 'train_steps_per_second': 0.247, 'total_flos': 886627990499328.0, 'train_loss': 3.085471360683441, 'epoch': 0.0063712935000063715})
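
The `epoch` and throughput values in this `TrainOutput` follow from the arguments in cell [15]. With `per_device_train_batch_size=1`, `gradient_accumulation_steps=1` and, as assumed here, one data-parallel process, 500 steps consume 500 of the 78,477 training rows. Because `eval_steps` is not set, transformers falls back to `logging_steps=2`, so the 100-row validation set is evaluated every 2 steps, and `save_steps=2` writes a checkpoint just as often; both add overhead to the 2,025-second runtime. A plain-Python back-of-the-envelope check (the `training_budget` helper is hypothetical):

```python
def training_budget(max_steps, per_device_bs, grad_accum, dataset_len, runtime_s,
                    save_steps, eval_steps, num_processes=1):
    """Back-of-the-envelope numbers for a step-limited Trainer run."""
    effective_bs = per_device_bs * grad_accum * num_processes
    examples_seen = max_steps * effective_bs
    return {
        "effective_batch_size": effective_bs,
        "epochs": examples_seen / dataset_len,
        "samples_per_second": examples_seen / runtime_s,
        "checkpoints": max_steps // save_steps,
        "evaluations": max_steps // eval_steps,
    }

# Values from cells [13], [15] and the TrainOutput above; eval_steps falls back to logging_steps=2.
budget = training_budget(max_steps=500, per_device_bs=1, grad_accum=1, dataset_len=78_477,
                         runtime_s=2025.1541, save_steps=2, eval_steps=2)
print(budget)
```

In other words, the run covers roughly 0.64% of one epoch.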

In [18]:
fine_tuned_model = peft_model.merge_and_unload()
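
`merge_and_unload()` folds each LoRA update into its base weight, W' = W + (lora_alpha / r) · B A, and removes the adapter layers, so the merged model runs at the base model's cost. The toy sketch below, in plain Python with made-up matrices, checks that identity: the merged layer gives the same output as the base layer plus the scaled low-rank path. The matrices are illustrative assumptions; `r = 1` and `alpha = 2` are chosen only to reproduce the notebook's scale factor of 32 / 16 = 2.

```python
def matmul(a, b):
    """Multiply matrices given as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def matvec(m, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in m]

def merge_lora(W, A, B, alpha, r):
    """Return W + (alpha / r) * (B @ A), the merged weight."""
    delta = matmul(B, A)
    scale = alpha / r
    return [[w + scale * d for w, d in zip(w_row, d_row)] for w_row, d_row in zip(W, delta)]

# Toy shapes: W is d_out x d_in (2 x 3), A is r x d_in (1 x 3), B is d_out x r (2 x 1).
W = [[1.0, 0.0, 2.0], [0.5, -1.0, 0.0]]
A = [[0.1, 0.2, 0.3]]
B = [[1.0], [-2.0]]
alpha, r = 2, 1
x = [1.0, 2.0, 3.0]

merged_out = matvec(merge_lora(W, A, B, alpha, r), x)
# Unmerged forward pass: base output plus the scaled low-rank path B(Ax).
lora_out = [base + (alpha / r) * b for base, b in zip(matvec(W, x), matvec(B, matvec(A, x)))]
print(merged_out, lora_out)
```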

In [19]:
torch.manual_seed(42)
sample = train_data[torch.randint(low=0, high=len(train_data), size=(1,)).item()]

In [20]:
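# `outputs` still holds the generation from cell [10] (the "Our doubts are traitors" quote prompt),
# so the "Completion" shown here is not a prediction for `sample`.
display(Markdown("#### Completion:"))
display(Markdown(tokenizer.decode(token_ids=outputs[0], skip_special_tokens=True)))
display(Markdown("#### Answer:"))
Markdown(sample["answer"])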

#### Completion:

Quote: Our doubts are traitors, and make us lose the good we oft might win, by fearing to attempt.

This quote is

#### Answer:

SELECT MAX(pop__2010_) FROM table_18600760_13 WHERE latitude = "48.676125"

In [21]:
torch.cuda.empty_cache()

not_tuned_model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path=model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    # device_map="cuda:0",
    attn_implementation="flash_attention_2"
)

fine_tuned_model.config.use_cache = True  # generation reads use_cache from the model config


In [22]:
def generate_responses(example, ft_model, og_model):
    prompt = template.format(context=example["context"], question=example["question"])
    input_ids = tokenizer(text=prompt, return_tensors="pt").to(device)
    ft_outputs = ft_model.generate(**input_ids, **generation_config)
    og_outputs = og_model.generate(**input_ids, **generation_config)

    display(Markdown("#### Prompt:"))
    display(Markdown(prompt))
    display(Markdown("#### Original Completion:"))
    display(Markdown(tokenizer.decode(token_ids=og_outputs[0], skip_special_tokens=True) \
           .replace(prompt, "")))
    display(Markdown("#### Fine-tuned Completion:"))
    display(Markdown(tokenizer.decode(token_ids=ft_outputs[0], skip_special_tokens=True) \
           .replace(prompt, "")))
    display(Markdown("#### Expected Answer:"))
    display(Markdown("`{answer}`".format(answer=example["answer"])))
    display(Markdown("-----------------------------"))

In [23]:
for i in range(5):
    generate_responses(val_data[i], ft_model=fine_tuned_model, og_model=not_tuned_model)

#### Prompt:

You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables.

You must output the SQL query that answers the question.

### Input:
```What is Division One's average Other Apps, with a League Goal less than 1?```

### Context:
```CREATE TABLE table_name_76 (other_apps INTEGER, division VARCHAR, league_goals VARCHAR)```



#### Original Completion:

```INSERT INTO table_name_76 (other_apps, division, league_goals) VALUES (289852566, 'Division Two', 23)```

```INSERT INTO table_name_76 (other_apps, division, league_goals) VALUES (914577381, 'Division Two', 21)```

```INSERT INTO table_name_76 (other_apps, division,

#### Fine-tuned Completion:

###Output:CREATE TABLE table_name_7 (division VARCHAR, league_goals VARCHAR)What is Division One's average Other Apps, with a League Goals less than 1?CREATE TABLE table_name_77 (other_apps VARCHAR, league_goals VARCHAR)CREATE TABLE table_name_7 (division VARCHAR, league_goals VARCHAR)CREATE TABLE table_name VARCHAR, league_goals VARCHAR)What is Division One's average Other Apps

#### Expected Answer:

`SELECT AVG(other_apps) FROM table_name_76 WHERE division = "one" AND league_goals < 1`

-----------------------------

#### Prompt:

You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables.

You must output the SQL query that answers the question.

### Input:
```What is the Branding of the Frequency owned by Sound of Faith Broadcasting Group?```

### Context:
```CREATE TABLE table_name_23 (branding VARCHAR, owner VARCHAR)```



#### Original Completion:

|   branding  |   owner    |
| :----------- | :----------|
|   American  |   National |
|   Digital   |   Digital  |
|   Family    |   National |
|   Frequency |   Sound    |

### Expected Output:
```SELECT branding FROM table_name_23 WHERE owner = 'Sound of Faith Broadcasting Group';```

#### Input:
```What is the owner of the Frequency owned by National?```



#### Fine-tuned Completion:

###Output:CREATE TABLE table_name_12 (frequency VARCHAR, branding VARCHAR, owner VARCHAR)CREATE TABLE table_name_10 (branding VARCHAR, owner VARCHAR)CREATE TABLE table_name_1 (frequency VARCHAR, branding VARCHAR, owner VARCHAR)CREATE TABLE table_name_9 (frequency VARCHAR, branding VARCHAR, owner VARCHAR)What is the name of a table named table_name_8?CREATE TABLE table_name_9_1 (

#### Expected Answer:

`SELECT branding FROM table_name_23 WHERE owner = "sound of faith broadcasting group"`

-----------------------------

#### Prompt:

You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables.

You must output the SQL query that answers the question.

### Input:
```Who is safe if John and Nicole are eliminated?```

### Context:
```CREATE TABLE table_12305325_4 (safe VARCHAR, eliminated VARCHAR)```



#### Original Completion:

```INSERT INTO table_12305325_4 (safe, eliminated) VALUES ('John', 'Tom'), ('Nicole', 'Henry'), ('Tony', 'Frank')```

### Output:
```SELECT safe FROM table_12305325_4 WHERE eliminated IN ('John', 'Nicole')```

### Scoring:
This model is evaluated on SQL evaluation server. The input to the SQL evaluation server is a triple $(x,C,

#### Fine-tuned Completion:

###Expected output:CREATE TABLE table_12305325_4 (safe VARCHAR, eliminated VARCHAR)CREATE TABLE table_12305325_1 (safe VARCHAR, eliminated VARCHAR)CREATE TABLE table_12305325_1 (safe VARCHAR, eliminated VARCHAR)CREATE TABLE table_12305325_1 (safe VARCHAR, eliminated VARCHAR)

### Input:
```Who is

#### Expected Answer:

`SELECT safe FROM table_12305325_4 WHERE eliminated = "John and Nicole"`

-----------------------------

#### Prompt:

You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables.

You must output the SQL query that answers the question.

### Input:
```What is the lowest against that has less than 9 wins, and draws smaller than 0?```

### Context:
```CREATE TABLE table_name_84 (against INTEGER, wins VARCHAR, draws VARCHAR)```



#### Original Completion:

### Correct output:
```SELECT against FROM table_name_84 WHERE wins < 9 AND draws < 0 GROUP BY against```

### Incorrect output:
```SELECT against FROM table_name_84 WHERE wins < 9 AND draws < 0```

### Input:
```Which teams are older than 175, and score between 1 and 17?```

### Context:
```CREATE TABLE table_name_85 (name VARCHAR

#### Fine-tuned Completion:

```CREATE TABLE table_name_14 (draw VARCHAR)```CREATE TABLE table_name_12 (loss VARCHAR)CREATE TABLE table_name_13 (low VARCHAR, loss VARCHAR)CREATE TABLE table_name_85 (loss VARCHAR, low VARCHAR)

### Questions:CREATE TABLE table_name_16 (loss VARCHAR)CREATE TABLE table_name_17 (low VARCHAR)
SELECT loss VARCHAR, low VARCHAR

You are a powerful

#### Expected Answer:

`SELECT MIN(against) FROM table_name_84 WHERE wins < 9 AND draws < 0`

-----------------------------

#### Prompt:

You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables.

You must output the SQL query that answers the question.

### Input:
```Which model was made from 2001–2004, with a Torque of n·m (lb·ft) @ 3750, and an Engine code of n42b18 / n46b18?```

### Context:
```CREATE TABLE table_name_72 (model VARCHAR, engine_code VARCHAR, years VARCHAR, torque VARCHAR)```



#### Original Completion:

```INSERT INTO table_name_72 (model, engine_code, years, torque) VALUES ("BMW X3", "n42b18 / n46b18", "2001-2004", "251.75 N·m / 187 lb·ft @ 3750");```

### Output:
```SELECT * FROM table_name_72 WHERE engine_code = "n

#### Fine-tuned Completion:

```CREATE TABLE table_name_73 (years VARCHAR, torque VARCHAR)```CREATE TABLE table_name_74 (model VARCHAR, torque VARCHAR)CREATE TABLE table_name_74 (years VARCHAR, torque VARCHAR)CREATE TABLE table_name_75 (model VARCHAR, years VARCHAR, torque VARCHAR)```CREATE TABLE table_name_76 (model VARCHAR, torque VARCHAR)```CREATE TABLE table_name_77 (model VARCHAR, years VARCHAR,

#### Expected Answer:

`SELECT model FROM table_name_72 WHERE years = "2001–2004" AND torque = "n·m (lb·ft) @ 3750" AND engine_code = "n42b18 / n46b18"`

-----------------------------
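
The comparisons above are judged by eye. A minimal, hypothetical way to score them automatically is to pull the first `SELECT ...` statement out of a completion and compare it with the expected answer after light normalization (case, whitespace, quote style, trailing semicolon). This is a crude exact-match check, not an execution-based SQL evaluation, and the regex and normalization rules are assumptions; for example, it stops at the first newline or semicolon. The sample strings come from the outputs above.

```python
import re

# Capture from SELECT up to the first code fence, semicolon, newline or end of string.
SELECT_RE = re.compile(r"(SELECT\b.*?)(?:```|;|\n|$)", re.IGNORECASE)

def extract_first_select(completion):
    """Return the first SELECT statement in a completion, or None if there is none."""
    match = SELECT_RE.search(completion)
    return match.group(1).strip() if match else None

def normalize_sql(sql):
    """Lower-case, unify quotes, drop a trailing semicolon and collapse whitespace."""
    sql = sql.strip().rstrip(";").replace("'", '"').lower()
    return " ".join(sql.split())

def exact_match(completion, expected):
    """True when the first extracted SELECT matches the expected SQL after normalization."""
    predicted = extract_first_select(completion)
    return predicted is not None and normalize_sql(predicted) == normalize_sql(expected)

completion = "### Expected Output:\n```SELECT branding FROM table_name_23 WHERE owner = 'Sound of Faith Broadcasting Group';```"
expected = 'SELECT branding FROM table_name_23 WHERE owner = "sound of faith broadcasting group"'
print(exact_match(completion, expected))
```

A completion that contains only `INSERT` or `CREATE TABLE` statements, like most of those above, scores as a miss because no `SELECT` is found.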

In [24]:
model_save_name = "gemma7b-ft-lora-sql-v2"

In [25]:
# Save model & tokenizer
fine_tuned_model.push_to_hub(model_save_name)
tokenizer.push_to_hub(model_save_name)

Upload 4 LFS files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/2.11G [00:00<?, ?B/s]




In [26]:
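# Save adapters.
# Trainer.push_to_hub() takes a commit message as its first argument; the target repo comes from
# args.output_dir (or args.hub_model_id), which is why the commit below lands in "gemma7bit-lora-sql".
# merge_and_unload() in cell [18] already removed the LoRA layers from the shared base model, so push the
# adapters before merging; pushed afterwards, adapter_model.safetensors is essentially empty (40 bytes below).
trainer.push_to_hub(commit_message=model_save_name + "adapters")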

Upload 3 LFS files:   0%|          | 0/3 [00:00<?, ?it/s]

training_args.bin:   0%|          | 0.00/5.43k [00:00<?, ?B/s]

adapter_model.safetensors:   0%|          | 0.00/40.0 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/Liu-Xiang/gemma7bit-lora-sql/commit/c712f0ec701d35f136a75b2b893563e9d89f6dbf', commit_message='gemma7b-ft-lora-sql-v2adapters', commit_description='', oid='c712f0ec701d35f136a75b2b893563e9d89f6dbf', pr_url=None, pr_revision=None, pr_num=None)