<a href="https://colab.research.google.com/github/ffreemt/app1/blob/master/succ_mzwc_SFTTrainer_Fine_Tuning_Llama_2_merge_OOM_medical_apply_chat_template_Dekivadiya.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fine-Tuning Llama 2: Your Path to Medical Text Perfection
https://medium.com/@kevaldekivadiya2415/fine-tuning-llama-2-your-path-to-chemistry-text-perfection-aa4c54ff5790
- https://github.com/kevaldekivadiya2415/LLMs-Fine-Tuning

In [None]:
from IPython.display import clear_output

In [None]:
import rich
import torch
# Abort early unless a CUDA device is visible to torch; the rest of the
# notebook loads and trains a quantized model on the GPU.
if torch.cuda.is_available():
    rich.print("[green bold]GPU present, good to go[/green bold]")
else:
    rich.print("[red bold]No GPU. Turn on GPU and try again[/red bold]")
    raise SystemExit("Turn on GPU and try again")

In [None]:
!pip install -q transformers datasets peft accelerate bitsandbytes trl safetensors torch --no-cache
clear_output()
!date

Sat Jan  6 03:49:31 AM UTC 2024


## load_dataset: *medalpaca/medical_meadow_medqa*

In [None]:
from datasets import load_dataset
from random import randrange

# Fetch the "train" split of medalpaca/medical_meadow_medqa from the Hub.
dataset = load_dataset("medalpaca/medical_meadow_medqa", split="train")

# Report the row count, then show one randomly chosen record.
n_rows = len(dataset)
print(f"Dataset Size: {n_rows}")
print(dataset[randrange(n_rows)])
# Dataset Size: 10178

Dataset Size: 10178
{'instruction': 'Please answer with one of the option in the bracket', 'input': "Q:An 81-year-old man is brought to the physician by his daughter after he was found wandering on the street. For the last 3 months, he often has a blank stare for several minutes. He also claims to have seen strangers in the house on several occasions who were not present. He has hypertension and hyperlipidemia, and was diagnosed with Parkinson disease 8 months ago. His current medications include carbidopa-levodopa, hydrochlorothiazide, and atorvastatin. His blood pressure is 150/85 mm Hg. He has short-term memory deficits and appears confused and disheveled. Examination shows bilateral muscle rigidity and resting tremor in his upper extremities. He has a slow gait with short steps. Microscopic examination of the cortex of a patient with the same condition is shown. Which of the following is the most likely diagnosis?? \n{'A': 'Lewy body dementia', 'B': 'Creutzfeldt-Jakob disease', 'C'

# format_prompt


In [None]:
!pip install -q python-box

from box import Box

In [None]:
# YAML sketch of the prompt layout that format_prompt() produces.
# The text is stored as a literal block scalar (`|`): in a plain YAML scalar the
# indented `### ...` lines would be read as comments and the trailing `:` as a
# mapping indicator, so the template could not round-trip through a YAML parser.
# The key was also misspelled ("prompt_tempalte").
PROMPT_TEMPLATE_YAML = """
prompt_template: |
  Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

  ### Instruction:
  {sample["instruction"]}

  ### Input:
  {sample["input"]}

  ### Response:
  {sample["output"]}
"""
_ = PROMPT_TEMPLATE_YAML  # keep the original scratch binding

# PROMPT_TEMPLATE_YAML
# Box.from_yaml(PROMPT_TEMPLATE_YAML)

In [None]:
def format_prompt(sample, include_response=True):
    """Render one dataset record as an instruction/input/response prompt.

    Args:
        sample: Mapping with "instruction" and "input" entries, plus an
            "output" entry when ``include_response`` is true.
        include_response: When True (the default, used for training) the
            record's "output" is written after the "### Response:" header.
            When False the text stops right after that header, which is the
            form needed to prompt the model at inference time.

    Returns:
        The prompt text. It starts with a blank line and ends with a newline.
    """
    prompt = f"""
Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{sample["instruction"]}

### Input:
{sample["input"]}

### Response:
"""
    if include_response:
        # Training form: append the reference answer and a closing newline.
        prompt += f'{sample["output"]}\n'
    return prompt

In [None]:
from random import randrange

# Pick one random row and print it raw, followed by its rendered prompt.
# A descriptive name is used instead of `_`, which IPython uses for the
# previous cell's output.
sample_idx = randrange(len(dataset))
record = dataset[sample_idx]

print(record, format_prompt(record))

{'instruction': 'Please answer with one of the option in the bracket', 'input': "Q:A 25-year-old male wrestler presents to his primary care physician for knee pain. He was in a wrestling match yesterday when he was abruptly taken down. Since then, he has had pain in his left knee. The patient states that at times it feels as if his knee locks as he moves it. The patient has a past medical history of anabolic steroid abuse; however, he claims to no longer be using them. His current medications include NSAIDs as needed for minor injuries from participating in sports. On physical exam, you note medial joint tenderness of the patient’s left knee, as well as some erythema and bruising. The patient has an antalgic gait as you observe him walking. Passive range of motion reveals a subtle clicking of the joint. There is absent anterior displacement of the tibia relative to the femur on an anterior drawer test. The rest of the physical exam, including examination of the contralateral knee is wi

# Fine-tune Llama 2 using trl and the SFTTrainer

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Hugging Face model name
model_name = "meta-llama/Llama-2-7b-chat-hf"
use_flash_attention = False

# Use one half-precision dtype everywhere. The quantized compute dtype was
# bfloat16 while the model was loaded with torch_dtype=float16 and the training
# arguments enable fp16; mixing the two is inconsistent.
compute_dtype = torch.float16

# BitsAndBytesConfig int-4 config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
)

# Load model
# NOTE(review): `use_flash_attention_2` is reported as deprecated in newer
# transformers releases in favour of `attn_implementation` — confirm against
# the installed version.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    use_cache=False,
    use_flash_attention_2=use_flash_attention,
    device_map="auto",
    torch_dtype=compute_dtype,
)

# NOTE(review): presumably turns off the tensor-parallel emulation path used
# for the pretrained checkpoint — confirm.
model.config.pretraining_tp = 1

# Load tokenizer; reuse EOS as the padding token and pad on the right.
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

# LoraConfig

In [None]:
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model

# LoRA hyperparameters (noted in the original as based on the QLoRA paper).
lora_settings = dict(
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM",
)
peft_config = LoraConfig(**lora_settings)

# Prepare the quantized model for training, then attach the LoRA adapters.
model = get_peft_model(prepare_model_for_kbit_training(model), peft_config)

## TrainingArguments

In [None]:
from transformers import TrainingArguments

# Training hyperparameters for the SFT run; outputs go under output_dir.
args = TrainingArguments(
    output_dir="finetuned-llama-7b-chat-hf-med",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,  # 4 x 2 = 8 sequences per optimizer step per device
    gradient_checkpointing=True,
    optim="paged_adamw_32bit",
    logging_steps=10,
    save_strategy="epoch",
    learning_rate=2e-4,
    fp16=True,
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    # The plain "constant" schedule takes no warmup, so warmup_ratio above was
    # silently ignored; "constant_with_warmup" ramps up and then holds the LR.
    lr_scheduler_type="constant_with_warmup",
    disable_tqdm=False,
)

# SFTTrainer

In [None]:
from trl import SFTTrainer

# Max sequence length for the model and for packing of the dataset.
max_seq_length = 1024

# Train on the first 1/20 of the rows only; the full dataset was estimated at
# about 11 hrs. Pass `dataset` itself for the full run.
# (Scratch from the original cell: int(len(dataset)/20), dataset.select(range(10)))
train_subset = dataset.select(range(len(dataset) // 20))

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=train_subset,
    formatting_func=format_prompt,
    tokenizer=tokenizer,
    peft_config=peft_config,
    max_seq_length=max_seq_length,
    packing=True,
)

Generating train split: 0 examples [00:00, ? examples/s]

In [None]:
 # Preview the 1/20 subset passed to the trainer (features and row count).
 dataset.select(range(len(dataset) // 20))

Dataset({
    features: ['instruction', 'input', 'output'],
    num_rows: 508
})

In [None]:
# Full dataset, len(dataset) = 10178: ETA ~11 hrs
# 1/20 subset, int(len(dataset) / 20) = 508: ~31 min

10178

## Start trainer (full dataset ETA: ~11 hrs; the 1/20 subset used here takes about half an hour)

In [None]:
# Run fine-tuning with the SFTTrainer built in the previous section.
trainer.train()

# Full dataset: ETA ~11 hrs
# loss 1.42 ...(7)... 0.96

# 508-row subset
# loss: 0.91 0.92 0.88 0.824300

Step,Training Loss
10,0.9191
20,0.9274
30,0.8873
40,0.8592
50,0.8651
60,0.8243




TrainOutput(global_step=66, training_loss=0.87595196203752, metrics={'train_runtime': 1988.0042, 'train_samples_per_second': 0.264, 'train_steps_per_second': 0.033, 'total_flos': 2.13397058617344e+16, 'train_loss': 0.87595196203752, 'epoch': 3.0})

In [None]:
# List per-file sizes of the checkpoint saved at step 66 (the run's final global_step).
!du -sh finetuned-llama-7b-chat-hf-med/checkpoint-66/*

4.0K	finetuned-llama-7b-chat-hf-med/checkpoint-66/adapter_config.json
33M	finetuned-llama-7b-chat-hf-med/checkpoint-66/adapter_model.safetensors
65M	finetuned-llama-7b-chat-hf-med/checkpoint-66/optimizer.pt
8.0K	finetuned-llama-7b-chat-hf-med/checkpoint-66/README.md
16K	finetuned-llama-7b-chat-hf-med/checkpoint-66/rng_state.pth
4.0K	finetuned-llama-7b-chat-hf-med/checkpoint-66/scheduler.pt
4.0K	finetuned-llama-7b-chat-hf-med/checkpoint-66/special_tokens_map.json
4.0K	finetuned-llama-7b-chat-hf-med/checkpoint-66/tokenizer_config.json
1.8M	finetuned-llama-7b-chat-hf-med/checkpoint-66/tokenizer.json
4.0K	finetuned-llama-7b-chat-hf-med/checkpoint-66/trainer_state.json
8.0K	finetuned-llama-7b-chat-hf-med/checkpoint-66/training_args.bin


In [None]:
# Save the trained model through the trainer; no target path is passed here.
# NOTE(review): presumably this writes to args.output_dir, which the inference
# section loads from — confirm.
trainer.save_model()


In [None]:
# Show the output directory configured in TrainingArguments.
args.output_dir

'finetuned-llama-7b-chat-hf-med'

## Model Testing and Inference

In [None]:
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

# Reload the fine-tuned model and its tokenizer from the training output
# directory, with 4-bit loading and float16 as the torch dtype.
adapter_dir = args.output_dir
inference_load_kwargs = dict(
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    load_in_4bit=True,
)
model = AutoPeftModelForCausalLM.from_pretrained(adapter_dir, **inference_load_kwargs)
tokenizer = AutoTokenizer.from_pretrained(adapter_dir)

In [None]:
from datasets import load_dataset
from random import randrange

# Load the training split from the Hub again and draw one record at random.
dataset = load_dataset("medalpaca/medical_meadow_medqa", split="train")
sample_idx = randrange(len(dataset))
sample = dataset[sample_idx]

# Same layout as the training prompt, but the text ends right after the
# "### Response:" header so the model supplies the answer.
prompt = f"""
Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{sample["instruction"]}

### Input:
{sample["input"]}

### Response:
"""

# Tokenize to PyTorch tensors and move the ids to the GPU.
encoded = tokenizer(prompt, return_tensors="pt", truncation=True)
input_ids = encoded.input_ids.cuda()
input_ids

tensor([[    1, 29871,    13, 21140,   340,   338,   385, 15278,   393, 16612,
           263,  3414, 29892,  3300,  2859,   411,   385,  1881,   393,  8128,
          4340,  3030, 29889, 14350,   263,  2933,   393,  7128,  2486,  1614,
          2167,   278,  2009, 29889,    13,    13,  2277, 29937,  2799,  4080,
         29901,    13, 12148,  1234,   411,   697,   310,   278,  2984,   297,
           278,  4105,  3522,    13,    13,  2277, 29937, 10567, 29901,    13,
         29984, 29901,  2744, 29871, 29896, 29896, 29899,  6360, 29899,  1025,
          8023,   338,  6296,   304,   278, 11176, 14703, 14311,   491,   670,
         11825,   411,   263, 29871, 29906, 29899,  3250,  4955,   310,  1238,
           369, 29892,   286,  2883,   895, 29892,   322,  3234,   573,   274,
           820, 29889,  1551, 24329, 29892,   540,   338,  1476,   304,   367,
          1407,  8062,   322,   338,  2534, 14656, 16172,   292, 29889,  3600,
          4940, 16083,  4955,   338,  7282,   363,  

In [None]:
# Show the exact prompt text that was tokenized for generation.
print(prompt)


Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
Please answer with one of the option in the bracket

### Input:
Q:An 11-year-old boy is brought to the emergency department by his parents with a 2-day history of fever, malaise, and productive cough. On presentation, he is found to be very weak and is having difficulty breathing. His past medical history is significant for multiple prior infections requiring hospitalization including otitis media, upper respiratory infections, pneumonia, and sinusitis. His family history is also significant for a maternal uncle who died of an infection as a child. Lab findings include decreased levels of IgG, IgM, IgA, and plasma cells with normal levels of CD4 positive cells. The protein that is most likely defective in this patient has which of the following functions?? 
{'A': 'Actin polymerization', 'B': 'Autoimmune regulati

In [None]:
# Sample a completion for the prompt (sampling with top_p=0.6, temperature=0.9).
outputs = model.generate(
    input_ids=input_ids,
    max_new_tokens=512,
    do_sample=True,
    top_p=0.6,
    temperature=0.9,
)

# The returned sequence holds the prompt tokens followed by the continuation.
# Keep only the newly generated token ids and decode those: cutting the decoded
# text by len(prompt) characters is fragile, because decoding is not guaranteed
# to reproduce the prompt character-for-character.
prompt_token_count = input_ids.shape[1]
generated_ids = outputs[0, prompt_token_count:].tolist()
generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)

print(f"Instruction:\n{sample['instruction']}\n")
print(f"Input:\n{sample['input']}\n")
print(f"Generated Response:\n{generated_text}\n")
print(f"Ground Truth:\n{sample['output']}")



Instruction:
Please answer with one of the option in the bracket

Input:
Q:An 11-year-old boy is brought to the emergency department by his parents with a 2-day history of fever, malaise, and productive cough. On presentation, he is found to be very weak and is having difficulty breathing. His past medical history is significant for multiple prior infections requiring hospitalization including otitis media, upper respiratory infections, pneumonia, and sinusitis. His family history is also significant for a maternal uncle who died of an infection as a child. Lab findings include decreased levels of IgG, IgM, IgA, and plasma cells with normal levels of CD4 positive cells. The protein that is most likely defective in this patient has which of the following functions?? 
{'A': 'Actin polymerization', 'B': 'Autoimmune regulation', 'C': 'Lysosomal trafficking', 'D': 'Nucleotide salvage', 'E': 'Protein phosphorylation'},

Generated Response:
E: Protein phosphorylation


Ground Truth:
E: Protei

# Merge and save

RAM **OOM**: requires 14GB RAM?

In [None]:
import torch
from peft import AutoPeftModelForCausalLM

# Load base model + adapter in float16. Without an explicit torch_dtype the
# weights are materialised in float32, roughly doubling the RAM needed for a
# 7B model and making the out-of-memory failure noted above more likely.
model = AutoPeftModelForCausalLM.from_pretrained(
    # args.output_dir,
    "finetuned-llama-7b-chat-hf-med",
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
)

# Merge LoRA and base model
merged_model = model.merge_and_unload()

# Save the merged model together with the tokenizer loaded earlier.
merged_model.save_pretrained("merged-llama2-7b-chat-hf-med", safe_serialization=True)
tokenizer.save_pretrained("merged-llama2-7b-chat-hf-med")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
# List the files written to the merged-model directory.
!ls merged-llama2-7b-chat-hf-med

## From comments (`tokenizer.apply_chat_template`)

In [None]:
# import os
# from google.colab import userdata
# userdata.get('HF_TOKEN')

# os.environ.update({
#     "HF_TOKEN": userdata.get('HF_TOKEN')
# })

# os.getenv("HF_TOKEN")

In [None]:
import os
# Fail fast when no Hugging Face token is exported in the environment.
# NOTE(review): this only checks the HF_TOKEN env var, so it still fails for a
# session that authenticated via login() alone.
assert os.getenv("HF_TOKEN"), "For gated repo, you must set env var HF_TOKEN or execute huggingface_hub.login('your-hf-token')"

In [None]:
import locale


def _utf8_preferred_encoding(do_setlocale=True):
    """Replacement for locale.getpreferredencoding that always reports UTF-8.

    The parameter mirrors the standard-library signature so callers that pass
    ``do_setlocale`` positionally or by keyword keep working; its value is
    ignored.
    """
    return "UTF-8"


# Monkey-patch the module-level function for the rest of the session.
# NOTE(review): presumably a workaround for shell commands failing on a
# non-UTF-8 locale in Colab — confirm it is still needed.
locale.getpreferredencoding = _utf8_preferred_encoding


In [None]:
# Capture the output lines of `huggingface-cli whoami`.
o = !huggingface-cli whoami
# Count occurrences of "not" in the first output line (case-insensitive).
# NOTE(review): presumably a non-zero count means "not logged in" — confirm
# against the CLI's actual wording.
o[0].lower().count("not")

In [None]:
# Show the full captured CLI output.
o

In [None]:
from transformers import AutoTokenizer

In [None]:
# Load the tokenizer of the Llama-2 chat model from the Hub.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")


In [None]:
# Example conversation for the chat template. The reply "Hi Jack." answers the
# user's first turn, so it is an assistant message (it was mislabelled as a
# second system message).
messages = [
  {'role': 'system', 'content': 'You are a chatbot.'},
  {'role': 'user', 'content': 'I am Jack'},
  {'role': 'assistant', 'content': 'Hi Jack.'},
  {'role': 'user', 'content': 'How are you?'}
]

In [None]:
# Render the conversation with the tokenizer's chat template.
# (The leftover `?` help lookup was removed so the cell only shows the result.)
tokenizer.apply_chat_template(messages)

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
import torch

# Load tokenizer and model with QLoRA configuration
compute_dtype = torch.float16
quantization_config = BitsAndBytesConfig(
  # The bnb_4bit_* options below only apply to 4-bit loading; the previous
  # load_in_8bit=True contradicted them and the stated QLoRA intent.
  load_in_4bit=True,
  bnb_4bit_quant_type="nf4", # normal float4 (QLora: https://arxiv.org/pdf/2305.14314.pdf)
  bnb_4bit_compute_dtype=compute_dtype,
  bnb_4bit_use_double_quant=False,
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", quantization_config=quantization_config)

# Minimal conversation: a system prompt followed by one user turn.
messages = [
{'role': 'system', 'content': 'You are a chatbot.'},
{'role': 'user', 'content': 'How are you?'}
]

# NOTE(review): the "conversational" task is reported as dropped from newer
# transformers releases in favour of chat-style input to "text-generation" —
# confirm against the installed version.
pipe = pipeline("conversational", model=model, tokenizer=tokenizer)

r = pipe(messages)

print(r)