In [None]:
!pip3 install --upgrade transformers datasets bitsandbytes peft

In [1]:
from datasets import load_dataset
from transformers import BitsAndBytesConfig, LlamaForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling, TrainingArguments, Trainer
from peft import LoraConfig
import torch
import json
import re

# Prepare Model

In [2]:
quantization_config = BitsAndBytesConfig(load_in_4bit=True)

In [3]:
model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [4]:
model.get_memory_footprint()

5591548160

In [5]:
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")

In [6]:
tokenizer.pad_token_id = tokenizer.eos_token_id

# Prepare Data

In [7]:
# prepare table for cosql and sparc
with open("./data/cosql_dataset/tables.json") as json_file:
    json_data = json.load(json_file)
    
database = dict()

for db in json_data:
    database[db["db_id"]] = dict()
    for table in db["table_names_original"]:
        database[db["db_id"]][table] = []
    for column in db["column_names_original"]:
        table_id = column[0]
        col_name = column[1]
        if table_id != -1:
            database[db["db_id"]][db["table_names_original"][table_id]].append(col_name)

In [8]:
spider = load_dataset("spider")

In [9]:
spider

DatasetDict({
    train: Dataset({
        features: ['db_id', 'query', 'question', 'query_toks', 'query_toks_no_value', 'question_toks'],
        num_rows: 7000
    })
    validation: Dataset({
        features: ['db_id', 'query', 'question', 'query_toks', 'query_toks_no_value', 'question_toks'],
        num_rows: 1034
    })
})

In [10]:
def db_id_to_table(db_id):
    global database
    db = database[db_id]
    table = [f"table:{table}\ntable_column: {','.join(db[table])}\n" for table in db]
    table = "---\n".join(table)
    table += "---\n"
    return table

In [11]:
def clean_text(text):
    text = text.replace("\xa0", " ").strip()
    text = re.sub(r'\s+', ' ', text)
    return text

def clean_query(query):
    query = clean_text(query).replace(" , ", ", ")
    if query[-1] == ";":
        query = query[:-1]
    return query

def get_prompt(tables, question):
    question = clean_text(question)
    prompt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>You are SQL expert. The user will give you a database table information and question in this form

table: table_name_1
table_column: column_1,column_2,column_3
---
table: table_name_2
table_column: column_1,column_2
---
question: user question

You have to answer valid SQL query.
<|eot_id|><|start_header_id|>user<|end_header_id|>{tables}question:{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""
    return prompt

In [12]:
def spider_preprocess_function(examples):
    tables = db_id_to_table(examples["db_id"])
    examples["query"] = clean_query(examples["query"])
    examples["prompt"] = get_prompt(tables, examples["question"])
    return examples

In [13]:
spider_train = spider["train"].map(spider_preprocess_function, remove_columns=['db_id', 'question', 'query_toks', 'query_toks_no_value', 'question_toks'])
spider_val = spider["validation"].map(spider_preprocess_function, remove_columns=['db_id', 'question', 'query_toks', 'query_toks_no_value', 'question_toks'])

In [14]:
def spider_tokenize(examples):
    return tokenizer(examples["prompt"])

In [15]:
spider_train = spider_train.map(spider_tokenize, remove_columns=['prompt', 'query'])
spider_val = spider_val.map(spider_tokenize, remove_columns=['prompt', 'query'])

# Train

In [16]:
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False, return_tensors="pt")

In [17]:
for param in model.parameters():
    param.requires_grad = False  # freeze the model - train adapters later

In [18]:
peft_config = LoraConfig(
    r=64,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj", "o_proj"],
    lora_dropout=0.01,
    bias="none",
    task_type="CAUSAL_LM",
)

model.add_adapter(peft_config)

In [19]:
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

print_trainable_parameters(model)

trainable params: 44040192 || all params: 4584640512 || trainable%: 0.9606029498872953


In [20]:
train_args = TrainingArguments(
    output_dir="outputs",
    eval_strategy="epoch",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=32,
    learning_rate=1e-4,
    num_train_epochs=2,
    warmup_steps=10,
    logging_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=3,
    bf16=True,
)

In [21]:
trainer = Trainer(
    model=model,
    train_dataset=spider_train,
    eval_dataset=spider_val,
    args=train_args,
    data_collator=data_collator,
)
model.config.use_cache = False  # silence the warnings. Please re-enable for inference!

In [None]:
trainer.train()