In [None]:
pip install -q -U bitsandbytes transformers peft accelerate datasets trl

In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments, Trainer
import torch
from huggingface_hub import login
import os

# Disable Weights & Biases logging for this run.
os.environ["WANDB_DISABLED"] = "true"

# Read the Hugging Face token from the environment instead of hard-coding it in the
# notebook (the empty placeholder cannot authenticate). When HF_TOKEN is unset or empty,
# token=None makes `login` prompt for the token interactively instead.
login(
    token=os.environ.get("HF_TOKEN") or None,
    add_to_git_credential=True,
)

In [None]:
def formatting_prompts_func(sample):
    """Render a batch of text-to-SQL examples into training prompts.

    Args:
        sample: a batched dataset slice — a mapping whose "context" (schema
            statements separated by ";"), "question" and "answer" entries are
            equal-length lists of strings.

    Returns:
        list[str]: one prompt per example, made of the intro, instruction,
        schema (one statement per line), question and answer sections, joined
        with newlines and wrapped in literal "<bos>" / "<eos>" markers.
    """
    # Loop-invariant prompt pieces, built once per batch.
    INTRO = "<bos>### You are an expert in writing SQL queries based on schema given"
    INSTRUCTION = "\n Given below schema and a query, try to write a SQL script"
    END = "<eos>"
    # NOTE(review): the tokenizer may also prepend its own BOS token to this text, which
    # would yield two BOS tokens per example — confirm; keep in sync with `formatting_prompt`.

    output_texts = []
    for context, question, answer in zip(sample["context"], sample["question"], sample["answer"]):
        # One schema statement per line. schema_text always contains at least "\n",
        # so the schema section is always emitted.
        schema_text = "".join(part + "\n" for part in context.split(";"))
        SCHEMA = f"\n ###SCHEMA: {schema_text}"
        QUERY = f"\n ###QUERY: {question}"
        RESPONSE = f"\n\n ###RESPONSE: {answer}"
        output_texts.append("\n".join([INTRO, INSTRUCTION, SCHEMA, QUERY, RESPONSE, END]))
    return output_texts

In [None]:
# Reproducible subsample + split: without a seed every run trains/evaluates on different rows.
SEED = 42
N_SAMPLES = 12_500  # rows drawn from the full training split
N_TEST = 2_500      # of which held out for evaluation (same 20% as before)

dataset_raw = load_dataset("b-mc2/sql-create-context", split="train")
dataset = (
    dataset_raw
    .shuffle(seed=SEED)
    .select(range(N_SAMPLES))
    .train_test_split(test_size=N_TEST, seed=SEED)
)

In [None]:
from trl import setup_chat_format

# 4-bit NF4 quantization with float16 compute for memory-efficient (QLoRA-style) loading.
compute_dtype = torch.float16
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=False,
)

model_name = "google/gemma-2b-it"
device_map = {"": 0}  # put the entire model on GPU 0

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.padding_side = 'right'  # to prevent warnings

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quantization_config,
    device_map=device_map,
)

In [None]:
from peft import LoraConfig
 
# LoRA config based on QLoRA paper & Sebastian Raschka experiment.
# Gemma's attention projections are named q_proj / k_proj / v_proj / o_proj. The previous
# list used "dense", which is not a Gemma module name, so the attention output projection
# received no adapter; "o_proj" is the Gemma equivalent.
peft_config = LoraConfig(
    lora_alpha=32,
    lora_dropout=0.05,
    r=32,
    bias="none",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)

In [None]:
output_dir = './peft-gemma-2b-sql-SFTT'

args = TrainingArguments(
    output_dir=output_dir,                  # local directory for checkpoints and the final adapter
    num_train_epochs=1,                     # superseded by max_steps below
    max_steps=500,                          # stop after 500 optimizer steps
    per_device_train_batch_size=1,          # batch size per device during training
    gradient_accumulation_steps=4,          # micro-batches accumulated per optimizer update
    gradient_checkpointing=True,            # use gradient checkpointing to save memory
    optim="adamw_torch_fused",              # use fused adamw optimizer
    logging_steps=25,                       # log every 25 steps
    save_strategy="steps",                  # save a checkpoint every `save_steps` steps
    save_steps=100,
    learning_rate=2e-4,                     # learning rate, based on QLoRA paper
    fp16=True,                              # float16 mixed precision
    max_grad_norm=0.3,                      # max gradient norm based on QLoRA paper
    warmup_ratio=0.03,                      # warmup ratio based on QLoRA paper
    # The plain "constant" scheduler ignores warmup_ratio; "constant_with_warmup"
    # applies the warmup above and then holds the learning rate constant.
    lr_scheduler_type="constant_with_warmup",
    overwrite_output_dir=True,              # bool flag (a string would be truthy even as 'False')
)

In [None]:
from trl import SFTTrainer,DataCollatorForCompletionOnlyLM
 
# Maximum sequence length handed to SFTTrainer (packing is not enabled in this setup).
max_seq_length = 1024

# The completion-only collator masks the loss for everything up to this marker, so only
# the SQL answer is learned.
# NOTE(review): the marker follows "\n\n" in the prompt; SentencePiece may tokenize it
# differently in context than standalone, in which case the collator cannot find it — verify.
response_template = " ###RESPONSE:"
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    peft_config=peft_config,
    train_dataset=dataset['train'],
    formatting_func=formatting_prompts_func,
    data_collator=collator,
    max_seq_length=max_seq_length,
)

In [None]:
# Train, then write the final LoRA adapter to `output_dir` (the same directory the
# TrainingArguments use). Nothing is pushed to the Hub: push_to_hub is not enabled.
trainer.train()
trainer.save_model(output_dir)

In [None]:
from peft import AutoPeftModelForCausalLM
from transformers import pipeline
# Reload the saved tokenizer and the fine-tuned adapter (on top of its base model) in
# float16, then wrap them in a text-generation pipeline.
tokenizer = AutoTokenizer.from_pretrained(output_dir)
peft_model = AutoPeftModelForCausalLM.from_pretrained(
    output_dir,
    device_map="auto",
    torch_dtype=torch.float16,
)
pipe = pipeline("text-generation", model=peft_model, tokenizer=tokenizer)

In [None]:
def formatting_prompt(sample):
    """Build the inference prompt for a single text-to-SQL example.

    Uses the same layout as the training prompts built by `formatting_prompts_func`,
    but stops right after the "###RESPONSE:" marker (no answer, no "<eos>") so the
    model generates the SQL.

    Args:
        sample: one dataset row — a mapping with a "context" string (schema
            statements separated by ";") and a "question" string.

    Returns:
        str: the formatted prompt.
    """
    INTRO = "<bos>### You are an expert in writing SQL queries based on schema given"
    INSTRUCTION = "\n Given below schema and a query, try to write a SQL script"
    # One schema statement per line. schema_text always contains at least "\n",
    # so the schema section is always emitted.
    schema_text = "".join(part + "\n" for part in sample["context"].split(";"))
    SCHEMA = f"\n ###SCHEMA: {schema_text}"
    QUERY = f"\n ###QUERY: {sample['question']}"
    RESPONSE = "\n\n ###RESPONSE:"
    return "\n".join([INTRO, INSTRUCTION, SCHEMA, QUERY, RESPONSE])

In [None]:
# Load the untouched base model with the same 4-bit config, for a side-by-side
# comparison against the fine-tuned pipeline.
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=device_map,
    quantization_config=quantization_config,
)
base_pipe = pipeline(task="text-generation", model=base_model, tokenizer=tokenizer)

In [None]:
from random import randint

eval_dataset = dataset['test']

# Greedy decoding. With do_sample=False, `generate` ignores temperature/top_k/top_p
# (passing them only produces warnings), so they are omitted. Both pipelines share
# the same tokenizer, hence the same eos/pad ids.
generation_kwargs = dict(
    max_new_tokens=256,
    do_sample=False,
    eos_token_id=pipe.tokenizer.eos_token_id,
    pad_token_id=pipe.tokenizer.pad_token_id,
)

for _ in range(5):
    # randint is inclusive on both ends, so the last valid index is len - 1.
    rand_idx = randint(0, len(eval_dataset) - 1)
    example = eval_dataset[rand_idx]

    # Test on sample
    prompt = formatting_prompt(example)
    outputs = pipe(prompt, **generation_kwargs)
    base_model_outputs = base_pipe(prompt, **generation_kwargs)

    print(f"Context:\n{example['context']}")
    print(f"Query:\n{example['question']}")
    print(f"Original Answer:\n{example['answer']}")

    # generated_text echoes the prompt; keep only what follows the response marker.
    print(f"Generated Answer:\n{outputs[0]['generated_text'].split('###RESPONSE:')[1]}")

    print(f"Base model Generated Answer:\n{base_model_outputs[0]['generated_text'].split('###RESPONSE:')[1]}")

    print("\n\n")