#### Step-1 Installing the required libraries

In [None]:
!pip install peft bitsandbytes transformers trl dataset torch 

In [None]:
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from peft import LoraConfig, PeftModel , prepare_model_for_kbit_training , get_peft_model
from trl import SFTTrainer


# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

#### Step-2 Getting the base model and tokenizer 

https://discuss.huggingface.co/t/understanding-how-changing-bnb-4bit-compute-dtype-affects-outputs/52167

In [None]:
base_model_id = 'mistralai/Mixtral-8x7B-v0.1'

# 4-bit quantization settings that are applied while the checkpoint is loaded.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,  # if your GPU supports it
    # Double quantization also quantizes the quantization constants to save more memory.
    # Pass this keyword exactly once (a repeated keyword argument is a SyntaxError);
    # set it to False if you prefer to skip the second quantization pass.
    bnb_4bit_use_double_quant=True,
)

base_model = AutoModelForCausalLM.from_pretrained(base_model_id, quantization_config=bnb_config, device_map="auto")
# Make sure the model is distributed appropriately across your devices; here this is left to the "auto" strategy.

In [None]:
# Training_tokenizer (https://huggingface.co/docs/transformers/v4.37.2/en/model_doc/auto#transformers.AutoTokenizer.from_pretrained)
# https://huggingface.co/docs/transformers/main_classes/tokenizer#transformers.PreTrainedTokenizer

# Tokenizer for the base model: pad and truncate on the right, and request
# BOS/EOS tokens around every encoded sequence.
# NOTE(review): createPrompt also writes '<s>' and '</s>' literally, so encoded
# training samples may end up with duplicated special tokens — confirm.
tokenizer = AutoTokenizer.from_pretrained(
    base_model_id,
    add_bos_token=True,
    add_eos_token=True,
    padding_side="right",
    truncation_side="right",
)

# Reuse the EOS token for padding.
tokenizer.pad_token = tokenizer.eos_token

#### Step-3 Loading the dataset and preparing the quantized model for training

In [None]:
# Read the "train" and "test" fields of the same JSON file into two separate datasets.
train_ds, test_ds = (
    load_dataset("json", data_files='codes.jsonl', field=split_field)
    for split_field in ("train", "test")
)

In [None]:
# Enable gradient checkpointing on the base model before it is prepared for training.
base_model.gradient_checkpointing_enable()
# The weights were already quantized when base_model was loaded with quantization_config;
# this call does not quantize again, it prepares the k-bit model for training.
# NOTE(review): see the PEFT docs for exactly what this helper changes on the model.
model = prepare_model_for_kbit_training(base_model)

#### Step-4 Creating a prompt template

Most chat models are trained with a chat template (a common one is ChatML), and you can usually find it in the model description.

Here we are using Mistral AI's Mixtral model (mistralai/Mixtral-8x7B-v0.1).

The chat template a model uses is usually documented in its Hugging Face repository, and it is important to check it before formatting your own dataset. In our case, the Mixtral-8x7B-v0.1 model uses no specific prompt template.

![mixtral-template](Images/mixtral-m1.png)


So we will be using our base template that is :
\<s\>[INST]  System Prompt \n User Prompt \[\/INST\] Answer \</s\>


In [None]:
def createPrompt(example):
    """Format one dataset record as a single training prompt string.

    Args:
        example: Mapping with an 'Input' entry (the user request) and an
            'Output' entry (the expected answer).

    Returns:
        The string '<s>[INST] <system prompt> \\n <Input> [/INST]<Output> </s>'.
    """
    instruction = '[INST] You are a medical coding  model and your role is to give the medical codes \n'
    pieces = [
        '<s>',
        instruction,
        " {} [/INST]".format(example['Input']),
        "{} </s>".format(example['Output']),
    ]
    return "".join(pieces)

#### Step-5 Using the PEFT technique for finetuning

In [None]:
def printParameters(model):
    """Print the total and trainable parameter counts of ``model``.

    Args:
        model: Object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs; each parameter must provide
            ``numel()`` and ``requires_grad``.

    Returns:
        None. A one-line summary is written to stdout.
    """
    trainable_param = 0
    total_params = 0
    for _, param in model.named_parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_param += count

    # A model without any parameters would otherwise raise ZeroDivisionError.
    trainable_pct = 100 * trainable_param / total_params if total_params else 0.0

    print(f"Total params : {total_params} , trainable_params : {trainable_param} , trainable % : {trainable_pct} ")

In [None]:
# LoRA adapter configuration.
peft_config = LoraConfig(
    r=64,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    # Names of the linear layers that should receive LoRA adapters.
    target_modules=[
        # attention projections
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        # Mixtral expert MLP projections are named w1 / w2 / w3
        "w1",
        "w2",
        "w3",
        # Llama/Mistral-7B style MLP names; they match nothing in Mixtral but are
        # kept so the config still works if base_model_id is switched to such a model.
        "gate_proj",
        "up_proj",
        "down_proj",
        "lm_head",
    ],
    task_type="CAUSAL_LM"
)

#### Step-6 Creating a PEFT model and training it

In [None]:
# Attach the LoRA adapters to the k-bit-prepared model and report how many
# parameters will be trained.
model = get_peft_model(model , peft_config)
printParameters(model)

# With several GPUs, mark the model as model-parallel.
# NOTE(review): presumably this keeps the Trainer from wrapping the already
# sharded model in DataParallel — confirm.
if torch.cuda.device_count() > 1:
    model.is_parallelizable = True
    model.model_parallel = True

# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/training_args.py#L161

# max_steps and num_train_epochs :
# 1 epoch = [ training_examples / (no_of_gpu * batch_size_per_device) ] steps


args = TrainingArguments(
  output_dir = "LLama-2 7b",
  # num_train_epochs=1000,
  max_steps = 1000, # comment out this line if you want to train in epochs
  per_device_train_batch_size = 4,
  # warmup_steps expects an integer number of steps; 0.03 is a fraction of the
  # total training steps, which is what warmup_ratio is for.
  warmup_ratio = 0.03,
  gradient_accumulation_steps = 1,
  logging_steps=10,
  logging_strategy= "steps",
  save_strategy="steps",
  save_steps = 10,
  evaluation_strategy="steps",
  eval_steps=10, # comment out this line if you want to evaluate at the end of each epoch
  learning_rate=2.5e-5,
  bf16=True, # if your GPUs support this
  logging_nan_inf_filter = False, # log nan/inf loss values instead of filtering them, so a diverging run is visible
  # lr_scheduler_type='constant',
  save_safetensors = True,
)

# `model` is already a PEFT model (get_peft_model above), so peft_config is not
# passed to the trainer a second time.
trainer = SFTTrainer(
  model=model,
  max_seq_length=350,
  tokenizer=tokenizer,
  packing=True,
  formatting_func=createPrompt, # createPrompt is applied to every example of the train and eval datasets
  args=args,
  train_dataset=train_ds["train"],
  eval_dataset=test_ds["train"]
)

# Disable the generation KV cache while training.
model.config.use_cache = False
trainer.train()

#### Step-7 Generating outputs from the fine-tuned model

In [None]:
# Load the trained adapter on top of the base model so we can generate some outputs from it.
# The Trainer saves checkpoints as "<output_dir>/checkpoint-<step>"; with output_dir="LLama-2 7b"
# and save_steps=10 the first one is "LLama-2 7b/checkpoint-10". Replace it with the checkpoint you want.
# NOTE(review): base_model was already passed through get_peft_model earlier in this notebook;
# presumably a freshly loaded base model is the cleaner starting point here — confirm.

ft_model = PeftModel.from_pretrained(base_model , 'LLama-2 7b/checkpoint-10')

In [None]:
# NOTE(review): this system prompt differs from the one createPrompt used for training;
# results are usually better when both match — confirm which one is intended.
eval_prompt = "<s>[INST] You are a coding model and your goal is to correctly tell the medical codes to the user based on the prompt they have entered and you get rewarded for correct output \n Tell me the medical code for cholera disease [/INST]"

# The prompt already starts with '<s>' and must stay open-ended, so do not let the
# tokenizer (created with add_bos_token/add_eos_token=True) add a second BOS or a
# trailing EOS. Send the tensors to the device selected at the top of the notebook.
model_input = tokenizer(eval_prompt, return_tensors="pt", add_special_tokens=False).to(device)

ft_model.eval()
with torch.no_grad():
    generated_ids = ft_model.generate(**model_input, max_new_tokens=150, repetition_penalty=1.15)[0]
    print(tokenizer.decode(generated_ids, skip_special_tokens=True))
