# **Installing all of the packages**

In [1]:
#!pip install -U accelerate bitsandbytes peft transformers==4.39 datasets trl git-lfs wandb flash-attn sql-metadata scipy sqlglot

In [2]:
import os
import torch
import pandas as pd
from transformers import BitsAndBytesConfig, AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from peft import LoraConfig, TaskType
from datasets import load_dataset
from sql_metadata import Parser
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from tqdm import tqdm

# **Connect to wandb**

In [3]:
os.environ["WANDB_PROJECT"]="qwen_finetuning"

# **Loading the model**

In [4]:
from modelscope import AutoModelForCausalLM, AutoTokenizer
model_name = "qwen/CodeQwen1.5-7B-Chat"
# model_name = "mistralai/Mistral-7B-Instruct-v0.2"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_compute_dtype = torch.float16,
#     bnb_4bit_quant_type='nf4',
#     bnb_4bit_use_double_quant = True
# )

model = AutoModelForCausalLM.from_pretrained(
    "model_name",
    #quantization_config=bnb_config,
    attn_implementation="flash_attention_2",
    torch_dtype = torch.bfloat16,
    device_map='auto',
)
model.config.use_cache = False
tokenizer = AutoTokenizer.from_pretrained("model_name")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "right"

2024-04-24 01:27:29,578 - modelscope - INFO - PyTorch version 2.0.1 Found.
2024-04-24 01:27:29,580 - modelscope - INFO - Loading ast index from /hpc2hdd/home/jzhao815/.cache/modelscope/ast_indexer
2024-04-24 01:27:29,641 - modelscope - INFO - Loading done! Current index file version is 1.13.0, with md5 8962eeb014e66d8db494c341ba0f48c1 and a total number of 972 components indexed


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

  return self.fget.__get__(instance, owner)()


In [6]:
print(model)

GPTBigCodeForCausalLM(
  (transformer): GPTBigCodeModel(
    (wte): Embedding(49152, 4096)
    (wpe): Embedding(8192, 4096)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-41): 42 x GPTBigCodeBlock(
        (ln_1): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (attn): GPTBigCodeFlashAttention2(
          (c_attn): Linear(in_features=4096, out_features=4352, bias=True)
          (c_proj): Linear(in_features=4096, out_features=4096, bias=True)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (mlp): GPTBigCodeMLP(
          (c_fc): Linear(in_features=4096, out_features=16384, bias=True)
          (c_proj): Linear(in_features=16384, out_features=4096, bias=True)
          (act): GELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((4096,),

# **Loading the dataset**

In [7]:
data_files = {"train": "./training/spider_filtered_finetuning_dataset.csv", "validation": "./validation/filtered_spider_syn_dataset.csv"}
dataset = load_dataset('csv', data_files=data_files)

**Filtering rows with max tokens**

In [8]:
def formatting_prompts_func(training_dataset):
  output_texts = []
  for i in range(len(training_dataset['question'])):
    question = training_dataset['question'][i]
    correct_tables = training_dataset['correct_tables'][i]
    correct_columns = training_dataset['correct_columns'][i]
    database_schema = training_dataset['database_schema'][i]
    if correct_columns:
        correct_columns = ", ".join(set(correct_columns.split(", ")))
    correct_tables = ", ".join(set(correct_tables.split(", ")))
    user_message = f"""Given the following SQL tables, your job is to determine the columns and tables that the question is referring to.
{database_schema}
###
Question: {question}
"""
    assitant_message = f"""
```SQL
-- Columns: {correct_columns}
-- Tables: {correct_tables} ;
```
"""
    messages = [
    {"role": "user", "content": user_message},
    {"role": "assistant", "content": assitant_message},
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False)
    output_texts.append(text)
  return output_texts

In [9]:
response_template = "<|im_start|>" #Qwencode
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

# **Training Config**

In [10]:
lora_r = 64
lora_alpha = 32
lora_dropout = 0.1
output_dir = "./SFT"
num_train_epochs = 3
bf16 = True
overwrite_output_dir = True
per_device_train_batch_size = 16
per_device_eval_batch_size = 2
gradient_accumulation_steps = 16
gradient_checkpointing = True
evaluation_strategy = "steps"
learning_rate = 5e-5
weight_decay = 0.01
lr_scheduler_type = "cosine"
warmup_ratio = 0.01
max_grad_norm = 0.3
group_by_length = True
auto_find_batch_size = True
save_steps = 50
logging_steps = 10
load_best_model_at_end= False
packing = False
save_total_limit=3
neftune_noise_alpha=5
report_to="wandb"
max_seq_length = 2100 #set based on the maximum number of tokens

In [11]:
peft_config = LoraConfig(
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    r=lora_r,
    target_modules=[
        "q_proj",
        "v_proj",
        "k_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
        "lm_head"
    ],
    task_type=TaskType.CAUSAL_LM,
)

In [12]:
training_arguments = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=overwrite_output_dir,
    num_train_epochs=num_train_epochs,
    load_best_model_at_end=load_best_model_at_end,
    per_device_train_batch_size=per_device_train_batch_size,
    evaluation_strategy=evaluation_strategy,
    max_grad_norm = max_grad_norm,
    auto_find_batch_size = auto_find_batch_size,
    save_total_limit = save_total_limit,
    gradient_accumulation_steps=gradient_accumulation_steps,
    save_steps=save_steps,
    logging_steps=logging_steps,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    bf16=bf16,
    warmup_ratio=warmup_ratio,
    group_by_length=group_by_length,
    lr_scheduler_type=lr_scheduler_type,
    report_to=report_to,
    neftune_noise_alpha= neftune_noise_alpha
)

In [13]:
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset['train'],
    eval_dataset=dataset['validation'],
    peft_config=peft_config,
    formatting_func=formatting_prompts_func,
    data_collator=collator,
    tokenizer=tokenizer,
    args=training_arguments,
    max_seq_length=max_seq_length,
    packing=packing
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [14]:
trainer.train()

[34m[1mwandb[0m: Currently logged in as: [33mjzhao815[0m ([33mzjj2323[0m). Use [1m`wandb login --relogin`[0m to force relogin


CREATE TABLE `circuits` (
  circuitId INTEGER,
  circuitRef TEXT,
  name TEXT,
  location TEXT,
  country TEXT,
  lat REAL,
  lng REAL,
  alt TEXT,
  url TEXT
);
Sample rows from `circuits`:
1, albert_park, Albert Park Grand Prix Circuit, Melbourne, Australia, -37.8497, 144.968, 10, http://en.wikipedia.org/wiki/Melbourne_Grand_Prix_Circuit
2, sepang, Sepang International Circuit, Kuala Lumpur, Malaysia, 2.76083, 101.738,, http://en.wikipedia.org/wiki/Sepang_International_Circuit
3, bahrain, Bahrain International Circuit, Sakhir, Bahrain, 26.0325, 50.5106,, http://en.wikipedia.org/wiki/Bahrain_International_Circuit

CREATE TABLE `races` (
  raceId INTEGER,
  year INTEGER,
  round INTEGER,
  circuitId INTEGER REFERENCES circuits(circuitId),
  name TEXT,
  date TEXT,
  time TEXT,
  url TEXT
);
Sample rows from `races`:
1, 2009, 1, 1, Australian Grand Prix, 2009-03-29, 06:00:00, http://en.wikipedia.org/wiki/2009_Australian_Grand_Prix
2, 2009, 2, 2, Malaysian Grand Prix, 2009-04-05, 09:00:0

Step,Training Loss,Validation Loss


CREATE TABLE `Document_Types` (
  document_type_code VARCHAR(10) PRIMARY KEY,
  document_description VARCHAR(255)
);
Sample rows from `Document_Types`:
APP, Initial Application
REG, Regular

CREATE TABLE `Documents` (
  document_id INTEGER,
  document_type_code VARCHAR(10) REFERENCES Document_Types(document_type_code),
  grant_id INTEGER REFERENCES Grants(grant_id),
  sent_date DATETIME,
  response_received_date DATETIME,
  other_details VARCHAR(255)
);
Sample rows from `Documents`:
1, APP, 5, 1986-11-30 07:56:35, 1977-12-01 02:18:53, 
2, APP, 13, 2004-01-23 11:57:08, 1979-12-08 10:38:07, 
3, REG, 10, 1999-03-03 12:25:58, 1995-09-12 13:13:48, 

CREATE TABLE `Grants` (
  grant_id INTEGER,
  organisation_id INTEGER REFERENCES Organisations(organisation_id),
  grant_amount DECIMAL(19,4),
  grant_start_date DATETIME,
  grant_end_date DATETIME,
  other_details VARCHAR(255)
);
Sample rows from `Grants`:
1, 10, 4094.542, 2016-11-20 00:18:51, 2004-10-24 09:09:39, et
2, 3, 281.2446, 1985-10-09 

In [None]:
output_dir = os.path.join("./", "final_checkpoint_part1_2")
trainer.model.save_pretrained(output_dir)

