In [None]:
from datasets import load_dataset
from random import randrange
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model, AutoPeftModelForCausalLM
from trl import SFTTrainer
import mlflow
import shutil
from os.path import dirname

In [None]:
def remove_dir(dir_path):
    """Best-effort recursive removal of ``dir_path``.

    Used to clear output folders left over from a previous run. Filesystem
    failures are reported but never raised, so a missing folder does not stop
    the notebook.

    Args:
        dir_path: Path of the directory to delete.

    Returns:
        True if the directory was deleted; False if it did not exist or could
        not be removed.
    """
    try:
        shutil.rmtree(dir_path)
    except FileNotFoundError:
        # Nothing to clean up (e.g. first run) -- not an error.
        print(f"Folder '{dir_path}' does not exist; nothing to delete.")
        return False
    except OSError as e:
        # Report the real failure instead of claiming the folder was deleted.
        print(f"Folder '{dir_path}' could not be deleted: {e}")
        return False
    print(f"Folder '{dir_path}' has been deleted.")
    return True

# --- Run configuration -------------------------------------------------------
# `base_model` starts out as the same name as `base_model_name`; a later cell
# rebinds it to the loaded model object, so `base_model_name` keeps the string.
base_model_name = "bloom-1b1"
base_model = base_model_name

training_output = "training_output"  # Trainer output_dir: checkpoints + saved adapter
merged_model = "merged_ft_model"     # destination of the merged (base + LoRA) model

# Presumably a local copy of the Hub dataset "Shreyasrp/Text-to-SQL" -- confirm.
dataset_name = "text-to-sql_dataset"
split = "train[:10%]"  # only the first 10% of the train split

if torch.cuda.is_available():
    device_map = "cuda:0"
else:
    device_map = "cpu"
print(f"Device to be used is {device_map}")
print("--------------------------------------\n")

# Start from a clean slate: drop output folders left by a previous run.
for stale_dir in (training_output, merged_model):
    remove_dir(stale_dir)

In [None]:
# 8-bit weight quantization settings.
# NOTE(review): currently unused -- the model-loading cell passes no
# quantization_config (its quantized variant is commented out).
# 4-bit alternative: BitsAndBytesConfig(load_in_4bit=True,
#   bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4",
#   bnb_4bit_compute_dtype="float16")
bnb_config = BitsAndBytesConfig(load_in_8bit=True)

# LoRA adapter settings handed to SFTTrainer via `peft_config`.
# No target_modules are given, so PEFT falls back to its default modules for
# this architecture (TODO confirm which BLOOM modules that resolves to).
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,              # rank of the low-rank update matrices
    lora_alpha=32,     # scaling numerator (scale = lora_alpha / r)
    lora_dropout=0.05,
    bias="none",       # no bias parameters are trained
)

def prompt_instruction_format(sample):
  """Render one dataset row as the training text fed to SFTTrainer.

  Reads the row's 'instruction' and 'output' fields. Every line after
  'Context:' carries a four-space indent, and the text ends with a newline
  followed by that same indent.

  NOTE(review): the evaluation prompt later in the notebook uses a different
  template ('# context:' / '# result:'); consider aligning the two.
  """
  pad = " " * 4
  lines = [
      "Context:",
      f"{pad}{sample['instruction']}",
      "",
      f"{pad}Result:",
      f"{pad}{sample['output']}",
      pad,
  ]
  return "\n".join(lines)

In [None]:
dataset = load_dataset(dataset_name, split=split)

# prompt_instruction_format reads these fields from every row; fail fast here
# instead of deep inside SFTTrainer if the dataset schema differs.
REQUIRED_COLUMNS = {"instruction", "output"}
missing_columns = REQUIRED_COLUMNS - set(dataset.column_names)
assert not missing_columns, (
    f"Dataset {dataset_name!r} is missing required columns: {sorted(missing_columns)}"
)

In [None]:
# Full-precision load. The quantized variant would additionally pass
# `quantization_config=bnb_config`.
# From here on `base_model` refers to the model object, not its name.
base_model = AutoModelForCausalLM.from_pretrained(
    base_model,
    device_map=device_map,
    use_cache=False,  # KV cache only helps generation, not training
)

In [None]:
def print_param_precision(model):
  """Print how many parameters ``model`` holds per dtype.

  Emits one line per dtype, in order of first appearance in
  ``model.named_parameters()``:
  ``<dtype>, <count in millions> M, <share of all parameters> %``.
  Prints nothing for a model without parameters.
  """
  counts = {}
  for _, param in model.named_parameters():
    counts[param.dtype] = counts.get(param.dtype, 0) + param.numel()
  total = sum(counts.values())
  for dtype, count in counts.items():
    print(f"{dtype}, {count / 10**6:.4f} M, {count / total*100:.2f} %")

def print_trainable_parameters(model):
  """Print the model's total and trainable (requires_grad) parameter counts, in millions."""
  n_total, n_trainable = 0, 0
  for param in model.parameters():
    size = param.numel()
    n_total += size
    if param.requires_grad:
      n_trainable += size
  print(f"Total parameters: {n_total/10**6:.4f} M")
  print(f"Trainable parameters: {n_trainable/10**6:.4f} M")
    

In [None]:
# Summary of the freshly loaded base model: size, parameter counts, dtypes.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
# get_memory_footprint() reports the model's own parameters/buffers in bytes,
# not overall device memory usage.
footprint_mb = base_model.get_memory_footprint() / 1024**2
print(f"{device} Memory Used: {footprint_mb:.4f} MB")
print("--------------------------------------\n")

print(f"Parameters loaded for model {base_model_name}:")
print_trainable_parameters(base_model)
print("\n")

print(f"Data types for loaded model {base_model_name}:")
print_param_precision(base_model)

In [None]:
# Loaded by name: `base_model` now holds the model object, not the checkpoint name.
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
# Pad on the right and reuse EOS as the padding token.
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token
# transformers only emits this warning while the flag is unset, so setting it
# silences the warning. NOTE(review): internal attribute -- re-check on upgrades.
tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True

In [None]:
# Hyper-parameters for the supervised fine-tuning run.
# Options considered but currently disabled: per_device_train_batch_size=4,
# gradient_checkpointing=True, gradient_accumulation_steps=2,
# optim="paged_adamw_32bit" / "paged_adamw_8bit", fp16=False, bf16=False,
# lr_scheduler_type="cosine".
trainingArgs = TrainingArguments(
    output_dir=training_output,
    learning_rate=2e-4,
    num_train_epochs=3,
    auto_find_batch_size=True,  # retry with a smaller batch size on CUDA OOM
    logging_steps=5,
    save_strategy="epoch",      # one checkpoint per epoch
    disable_tqdm=True,
)

# SFTTrainer applies the LoRA adapter from `peft_config` to `base_model` and
# packs the formatted samples into fixed-length sequences of 2048 tokens.
trainer = SFTTrainer(
    model=base_model,
    args=trainingArgs,
    tokenizer=tokenizer,
    train_dataset=dataset,
    formatting_func=prompt_instruction_format,
    packing=True,
    max_seq_length=2048,
    peft_config=peft_config,
)

In [None]:
print("Start Fine-Tuning")
# Runs logged through MLflow go to this experiment (assumes the Trainer's
# MLflow integration is active via the default report_to -- confirm).
mlflow.set_experiment("Fine-Tune TRL")
train_result = trainer.train()  # TrainOutput: global_step, training_loss, metrics
print("Training Done")

In [None]:
# With a PEFT-wrapped model this writes the adapter weights + config (not the
# full model) into the training output directory.
trainer.save_model(trainingArgs.output_dir)
print("Model Saved")

In [None]:
# Reload base model + saved LoRA adapter from the training output directory.
trained_model = AutoPeftModelForCausalLM.from_pretrained(
    trainingArgs.output_dir,
    device_map=device_map,
    return_dict=True,
)

# Fold the LoRA weights into the base weights and drop the PEFT wrappers, then
# write model and tokenizer side by side so the result loads with plain
# AutoModelForCausalLM / AutoTokenizer.
lora_merged_model = trained_model.merge_and_unload()
for artifact in (lora_merged_model, tokenizer):
    artifact.save_pretrained(merged_model)

In [None]:
# Restart the IPython kernel before running the cells below: they re-import what they need and reload the models from disk.

In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

def print_param_precision(model):
  """Print how many parameters ``model`` holds per dtype.

  Emits one line per dtype, in order of first appearance in
  ``model.named_parameters()``:
  ``<dtype>, <count in millions> M, <share of all parameters> %``.
  Prints nothing for a model without parameters.
  """
  counts = {}
  for _, param in model.named_parameters():
    counts[param.dtype] = counts.get(param.dtype, 0) + param.numel()
  total = sum(counts.values())
  for dtype, count in counts.items():
    print(f"{dtype}, {count / 10**6:.4f} M, {count / total*100:.2f} %")

def print_trainable_parameters(model):
  """Print the model's total and trainable (requires_grad) parameter counts, in millions."""
  n_total, n_trainable = 0, 0
  for param in model.parameters():
    size = param.numel()
    n_total += size
    if param.requires_grad:
      n_trainable += size
  print(f"Total parameters: {n_total/10**6:.4f} M")
  print(f"Trainable parameters: {n_trainable/10**6:.4f} M")
    

In [1]:
# Summary of the merged (base + LoRA) model written by the training section.
merged_model = "merged_ft_model"
device_map = "cuda:0" if torch.cuda.is_available() else "cpu"
# Load with the same device_map the label below prints, so the reported device
# is where the model actually lives (matches the comparison cell further down).
model_new = AutoModelForCausalLM.from_pretrained(merged_model, device_map=device_map)

print(f"{device_map} Memory Used: {model_new.get_memory_footprint() / 1024**2:.4f} MB")
print("\nParameters:")
print_trainable_parameters(model_new)
print("\nData types:")
print_param_precision(model_new)

cuda:0 Memory Used: 4063.8516 MB

Parameters:
Total parameters: 1065.3143 M
Trainable parameters: 1065.3143 M

Data types:
torch.float32, 1065.3143 M, 100.00 %


In [3]:
# Summary of the original, un-fine-tuned checkpoint for comparison.
base_model = "bloom-1b1"
device_map = "cuda:0" if torch.cuda.is_available() else "cpu"
# Load `base_model`, not `merged_model`: otherwise this summary would describe
# the fine-tuned model a second time.
old_model = AutoModelForCausalLM.from_pretrained(base_model, device_map=device_map)

# Merging LoRA folds the adapter into existing weights without adding
# parameters, so counts and dtypes are expected to match the merged summary.
print(f"{device_map} Memory Used: {old_model.get_memory_footprint() / 1024**2:.4f} MB")
print("\nParameters:")
print_trainable_parameters(old_model)
print("\nData types:")
print_param_precision(old_model)

cuda:0 Memory Used: 4063.8516 MB

Parameters:
Total parameters: 1065.3143 M
Trainable parameters: 1065.3143 M

Data types:
torch.float32, 1065.3143 M, 100.00 %


In [4]:
# Tokenizer + model for both checkpoints used in the side-by-side generation.
# Tokenizers do not take a device_map; only the models are placed on a device.
# NOTE(review): `model_new` / `old_model` from the summary cells above still
# hold full copies of these weights -- `del` them first if memory is tight.
ft_tokenizer = AutoTokenizer.from_pretrained(merged_model)
ft_model = AutoModelForCausalLM.from_pretrained(merged_model, device_map=device_map)

base_tokenizer = AutoTokenizer.from_pretrained(base_model)
# Rebinds `base_model` from the checkpoint name to the loaded model object.
base_model = AutoModelForCausalLM.from_pretrained(base_model, device_map=device_map)

In [5]:
mytask="CREATE TABLE trip (bus_stop VARCHAR, duration INTEGER), list all the bus stops from which a trip of duration below 100 started."
# NOTE(review): this template ('# context:' / '# result:') differs from the one
# used for fine-tuning (prompt_instruction_format: 'Context:' / 'Result:');
# consider aligning them.
prompt = f"""
# Instruction:
Use the context below to produce the result
# context:
{mytask}
# result:
"""

GENERATION_SEED = 42
GENERATION_KWARGS = dict(do_sample=True, max_new_tokens=100, top_p=0.9, temperature=0.5)


def generate_completion(model, tokenizer, prompt, seed=GENERATION_SEED, **generate_kwargs):
    """Sample a continuation of ``prompt`` and return only the newly generated text.

    The prompt is removed by token position rather than by ``len(prompt)`` on
    the decoded string, because decoding the prompt tokens is not guaranteed
    to reproduce ``prompt`` character-for-character.
    The RNG is re-seeded before sampling so both models are compared under the
    same seed and reruns give the same samples (modulo nondeterministic GPU
    kernels).
    """
    # The tokenizer returns the matching attention mask alongside input_ids.
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    torch.manual_seed(seed)
    output_ids = model.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        **generate_kwargs,
    )
    new_tokens = output_ids[0, encoded["input_ids"].shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)


print("--------------------------------------\n")
print(f"Prompt:\n{prompt}\n")
print("--------------------------------------\n")

print("Fine-tuned Model Result :\n")
print(generate_completion(ft_model, ft_tokenizer, prompt, **GENERATION_KWARGS))
print("--------------------------------------\n")

print("Base Model Result :\n")
print(generate_completion(base_model, base_tokenizer, prompt, **GENERATION_KWARGS))

--------------------------------------

Prompt:

# Instruction:
Use the context below to produce the result
# context:
CREATE TABLE trip (bus_stop VARCHAR, duration INTEGER), list all the bus stops from which a trip of duration below 100 started.
# result:


--------------------------------------

Fine-tuned Model Result :

SELECT bus_stop FROM trip WHERE duration < 100
    
--------------------------------------

Base Model Result :

CREATE TABLE trip (bus_stop VARCHAR, duration INTEGER), list all the bus stops from which a trip of duration below 100 started.
# context:
CREATE TABLE trip (bus_stop VARCHAR, duration INTEGER), list all the bus stops from which a trip of duration below 100 started.
# result:
CREATE TABLE trip (bus_stop VARCHAR, duration INTEGER), list all the bus stops from which a trip of duration below 100 started.
# context:
CREATE TABLE trip (bus_stop VARCHAR, duration INTEGER), list all
