# Finetuning a Gemma model for Q & A

https://medium.com/the-ai-forum/instruction-fine-tuning-gemma-2b-on-medical-reasoning-and-convert-the-finetuned-model-into-gguf-844191f8d329

## Constants

In [2]:
INPUT_LIMIT = 1024
TRAIN_SIZE = 1000
TEST_SIZE = 100
EVAL_SIZE = 50
SEED = 123

BASE_MODEL_ID = "google/gemma-2b-it"
# BASE_MODEL_ID = "google/gemma-2-9b-it"
GENERATE_KWARGS = dict(
    do_sample=True,
    max_new_tokens=512,
    temperature=1e-3,
)
EVAL_BATCH_SIZE = 4

TRAIN_MAX_LENGTH = 512
TRAIN_NUM_EPOCHS = 1
TRAIN_BATCH_SIZE = 4
TRAIN_GRADIENT_ACCUMULATION_STEPS = 1
TRAIN_LOGGING_STEPS = 10
EVAL_ACCUMULATION_STEPS = 4
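`GENERATE_KWARGS` turns sampling on but sets `temperature=1e-3`, which makes decoding almost greedy. The plain-Python sketch below, using made-up logits, shows why. Logits are divided by the temperature before the softmax, so a tiny temperature puts essentially all probability on the highest-scoring token. Setting `do_sample=False` would give strictly greedy decoding instead.

```python
import math


def softmax_with_temperature(logits: list[float], temperature: float) -> list[float]:
    """Turn logits into sampling probabilities at the given temperature."""
    scaled = [x / temperature for x in logits]
    # Subtract the maximum before exponentiating to avoid overflow.
    top = max(scaled)
    exps = [math.exp(x - top) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical next-token logits: the two best candidates are close.
logits = [2.0, 1.9, 0.5]
print(softmax_with_temperature(logits, 1.0))   # probability is spread over the candidates
print(softmax_with_temperature(logits, 1e-3))  # nearly all probability on the first token
```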

## Load Gemma 2B instruct model

https://huggingface.co/google/gemma-2b-it

In [3]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_storage=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    quantization_config=quantization_config
)
print(model.device)


OSError: You are trying to access a gated repo.
Make sure to request access at https://huggingface.co/google/gemma-2b-it and pass a token having permission to this repo either by logging in with `huggingface-cli login` or by passing `token=<your_token>`.
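The `BitsAndBytesConfig` above stores the weights in 4 bits with one scale per small block of weights, and matrix multiplications dequantize them to `bnb_4bit_compute_dtype` (bfloat16). NF4 places its 16 levels at quantiles of a normal distribution, and `bnb_4bit_use_double_quant=True` also quantizes the per-block scales. The sketch below is a simplified stand-in, not the bitsandbytes algorithm. It uses evenly spaced symmetric levels, keeps the scales in full precision, and assumes a block size of 64 purely for illustration.

```python
import random


def quantize_blockwise(weights: list[float], block_size: int = 64) -> tuple[list[int], list[float]]:
    """Map each weight to an integer in -7..7 using one absmax scale per block."""
    codes, scales = [], []
    for start in range(0, len(weights), block_size):
        block = weights[start:start + block_size]
        scale = max(abs(w) for w in block) or 1.0  # guard against all-zero blocks
        scales.append(scale)
        # Map [-scale, scale] onto the integer grid -7..7 and round.
        codes.extend(round(w / scale * 7) for w in block)
    return codes, scales


def dequantize_blockwise(codes: list[int], scales: list[float], block_size: int = 64) -> list[float]:
    """Recover approximate weights from the 4-bit codes and per-block scales."""
    return [code / 7 * scales[i // block_size] for i, code in enumerate(codes)]


random.seed(0)
weights = [random.gauss(0.0, 0.02) for _ in range(256)]  # toy weight vector
codes, scales = quantize_blockwise(weights)
restored = dequantize_blockwise(codes, scales)
max_error = max(abs(w - r) for w, r in zip(weights, restored))
print(f"{len(codes)} codes, {len(scales)} scales, max abs error {max_error:.5f}")
```

The rounding error per weight is at most half a level, that is `scale / 14` for its block, which is why small blocks (and therefore many scales) keep the error low.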

## Medical Q & A

https://huggingface.co/datasets/medalpaca/medical_meadow_medqa

In [3]:
from medqa_data import load_train_test_data

dataset = load_train_test_data(
    train_size=TRAIN_SIZE,
    test_size=TEST_SIZE,
    seed=SEED,
    input_limit=INPUT_LIMIT,
)
display(dataset)

def print_sample(sample: dict[str, str]):
    print("\n".join(f"\n# {k}\n{v}" for k, v in sample.items())[1:])

print_sample(dataset["test"][0])

DatasetDict({
    train: Dataset({
        features: ['input', 'output'],
        num_rows: 1000
    })
    test: Dataset({
        features: ['input', 'output'],
        num_rows: 100
    })
})

# input
Q:A 45-year-old man comes to the physician because of worsening shortness of breath and dry cough for 6 months. The patient's symptoms get worse when he walks more than about 150 yards. He also reports fatigue and difficulty swallowing solid foods. In cold weather, his fingers occasionally turn blue and become painful. He occasionally smokes cigarettes on weekends. His temperature is 37°C (98.6°F), and respirations are 22/min, pulse is 87/min, and blood pressure is 126/85 mm Hg. The skin over his trunk and arms is thickened and tightened. Fine inspiratory crackles are heard over bilateral lower lung fields on auscultation. Which of the following additional findings is most likely in this patient?? 
{'A': 'Decreased right atrial pressure', 'B': 'Increased lung compliance', 'C': 'Decreased diffusing capacity', 'D': 'Increased airway resistance', 'E': 'Decreased A-a gradient'},

# output
C: Decreased diffusing capacity


In [4]:
question = """
Q:A child is in the nursery one day after birth. A nurse notices a urine-like discharge being expressed through the umbilical stump. What two structures in the embryo are connected by the structure that failed to obliterate during the embryologic development of this child??
{'A': 'Pulmonary artery - aorta', 'B': 'Bladder - yolk sac', 'C': 'Bladder - small bowel', 'D': 'Liver - umbilical vein', 'E': 'Kidney - large bowel'},
Give your answer as a JSON dictionary with the "option" (a letter from A-E) and the  corresponding"option_text". No yapping.
""".strip()

chat = [{"role": "user", "content": question}]
input_ids = tokenizer.apply_chat_template(chat, add_generation_prompt=True, return_tensors="pt").to(model.device)
output_ids = model.generate(input_ids, **GENERATE_KWARGS)
input_ids = input_ids.to("cpu")
output_ids = output_ids.to("cpu")
print(tokenizer.decode(output_ids[0]))

<bos><start_of_turn>user
Q:A child is in the nursery one day after birth. A nurse notices a urine-like discharge being expressed through the umbilical stump. What two structures in the embryo are connected by the structure that failed to obliterate during the embryologic development of this child??
{'A': 'Pulmonary artery - aorta', 'B': 'Bladder - yolk sac', 'C': 'Bladder - small bowel', 'D': 'Liver - umbilical vein', 'E': 'Kidney - large bowel'},
Give your answer as a JSON dictionary with the "option" (a letter from A-E) and the  corresponding"option_text". No yapping.<end_of_turn>
<start_of_turn>model
The correct answer is **B**: The bladder and the yolk sac.

The bladder and the yolk sac are connected by the structure that failed to obliterate during the embryologic development of this child.<eos>
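The base model ignored the requested JSON format and answered in prose, so any scoring code has to handle both styles. `create_predict` in `medqa_data` is not shown in this notebook. The hypothetical `extract_option` below is only one way to recover the option letter: it tries JSON first and falls back to a regular expression.

```python
import json
import re
from typing import Optional

VALID_OPTIONS = {"A", "B", "C", "D", "E"}


def extract_option(reply: str) -> Optional[str]:
    """Return the option letter A-E from a model reply, or None if none is found."""
    # Remove Gemma's turn and sequence markers if they are still in the text.
    text = reply.replace("<end_of_turn>", "").replace("<eos>", "").strip()

    # Preferred format: the JSON dictionary the prompt asks for.
    json_match = re.search(r"\{.*\}", text, flags=re.DOTALL)
    if json_match:
        try:
            parsed = json.loads(json_match.group(0))
        except json.JSONDecodeError:
            parsed = None
        if isinstance(parsed, dict):
            option = str(parsed.get("option", "")).strip().upper()
            if option in VALID_OPTIONS:
                return option

    # Fallback: free text such as "The correct answer is **B**".
    # Only "answer"/"option" are case-insensitive; the letter itself must be uppercase.
    text_match = re.search(r"(?i:answer|option)\s+is\s*\**\s*\(?([A-E])\b", text)
    return text_match.group(1) if text_match else None


print(extract_option('{"option": "D", "text": "Amiodarone"}<end_of_turn>'))
print(extract_option("The correct answer is **B**: The bladder and the yolk sac.<eos>"))
```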


In [5]:
from medqa_data import reformat_sample

dataset = dataset.map(reformat_sample)
train_data, test_data = dataset["train"], dataset["test"]

print_sample(test_data[0])

# input
Q: A 45-year-old man comes to the physician because of worsening shortness of breath and dry cough for 6 months. The patient's symptoms get worse when he walks more than about 150 yards. He also reports fatigue and difficulty swallowing solid foods. In cold weather, his fingers occasionally turn blue and become painful. He occasionally smokes cigarettes on weekends. His temperature is 37°C (98.6°F), and respirations are 22/min, pulse is 87/min, and blood pressure is 126/85 mm Hg. The skin over his trunk and arms is thickened and tightened. Fine inspiratory crackles are heard over bilateral lower lung fields on auscultation. Which of the following additional findings is most likely in this patient?? 
{'A': 'Decreased right atrial pressure', 'B': 'Increased lung compliance', 'C': 'Decreased diffusing capacity', 'D': 'Increased airway resistance', 'E': 'Decreased A-a gradient'}
Give your answer as a JSON dictionary with the "option" (a letter from A-E) and the  corresponding"optio
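Comparing the raw and reformatted `test_data[0]` suggests what `reformat_sample` does. It normalizes `Q:A` to `Q: A`, drops the trailing comma after the options, and appends the JSON instruction. The training targets shown further down are JSON strings such as `{"option": "D", "text": "Amiodarone"}`; note that their key is `text`, not the `option_text` the prompt asks for. Since `medqa_data` is not shown, `reformat_sample_sketch` is a hypothetical reconstruction, and its `true_label` field is assumed from the accuracy cell that reads it.

```python
import json

# Assumed instruction text; the exact wording used by medqa_data may differ.
INSTRUCTION = (
    'Give your answer as a JSON dictionary with the "option" (a letter from A-E) '
    'and the corresponding "option_text". No yapping.'
)


def reformat_sample_sketch(sample: dict[str, str]) -> dict[str, str]:
    """Hypothetical version of reformat_sample: JSON-answer prompt, JSON target, label."""
    question = sample["input"]
    if question.startswith("Q:") and not question.startswith("Q: "):
        question = "Q: " + question[2:]  # "Q:A 45-year-old" -> "Q: A 45-year-old"
    question = question.rstrip().removesuffix(",")  # drop the comma after the options dict
    letter, option_text = sample["output"].split(": ", 1)  # "C: Decreased diffusing capacity"
    return {
        "input": f"{question}\n{INSTRUCTION}",
        "output": json.dumps({"option": letter, "text": option_text}),
        "true_label": letter,
    }


raw = {
    "input": "Q:A 45-year-old man ... Which finding is most likely?? \n{'A': 'x', 'C': 'Decreased diffusing capacity'},",
    "output": "C: Decreased diffusing capacity",
}
print(reformat_sample_sketch(raw)["output"])  # {"option": "C", "text": "Decreased diffusing capacity"}
```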

In [6]:
import numpy as np
from medqa_data import create_predict

batch_predict = create_predict(tokenizer, model, "gemma", batch=True, generate_kwargs=GENERATE_KWARGS)
test_data = test_data.map(batch_predict, batched=True, batch_size=EVAL_BATCH_SIZE)

gemma_accuracy = (np.asarray(test_data["gemma_label"]) == np.asarray(test_data["true_label"])).mean()
print(f"Gemma accuracy: {round(gemma_accuracy * 100, 1)}%")



Gemma accuracy: 19.0%


## Supervised finetuning (SFT)

https://huggingface.co/docs/trl/en/sft_trainer
https://huggingface.co/docs/bitsandbytes/main/en/fsdp_qlora
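The training cell later in this notebook uses TRL's `DataCollatorForCompletionOnlyLM` with `response_template="<start_of_turn>model\n"`, so the loss covers only the model's answer and not the question. For a single-turn example, the effect can be sketched in plain Python. Every label up to and including the response template is set to -100, the index PyTorch's cross-entropy loss ignores. The token IDs are made up, and the real collator also handles multi-turn chats via `instruction_template`.

```python
IGNORE_INDEX = -100  # torch.nn.CrossEntropyLoss skips this label by default


def mask_prompt_labels(input_ids: list[int], response_ids: list[int]) -> list[int]:
    """Labels for one single-turn example: only tokens after the response template count."""
    n = len(response_ids)
    for start in range(len(input_ids) - n + 1):
        if input_ids[start:start + n] == response_ids:
            end = start + n  # position of the first answer token
            return [IGNORE_INDEX] * end + input_ids[end:]
    # Template missing (e.g. cut off by max_seq_length): exclude the whole example.
    return [IGNORE_INDEX] * len(input_ids)


# Made-up token IDs: 2 = <bos>, 10-12 = user turn, 50-51 = response template, 7-9 and 1 = answer.
example = [2, 10, 11, 12, 50, 51, 7, 8, 9, 1]
print(mask_prompt_labels(example, [50, 51]))  # [-100, -100, -100, -100, -100, -100, 7, 8, 9, 1]
```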

In [7]:
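def create_chat_for_finetuning(samples: dict[str, list[str]]) -> list[str]:
    """Render a batch of (input, output) pairs as Gemma chat strings for SFT."""
    chat_texts = tokenizer.apply_chat_template(
        [
            [
                {"role": "user", "content": question},
                {"role": "assistant", "content": answer},
            ]
            for question, answer in zip(samples["input"], samples["output"])
        ],
        tokenize=False,
    )
    # Drop the leading <bos>: the tokenizer adds it again when the trainer tokenizes the text.
    return [text.removeprefix("<bos>") for text in chat_texts]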

display(create_chat_for_finetuning(train_data[:3]))

['<start_of_turn>user\nQ: A 63-year-old man with a history of hypertension and atrial fibrillation is brought into the emergency room and found to have a ventricular tachyarrhythmia. Ibutilide is discontinued and the patient is switched to another drug that also prolongs the QT interval but is associated with a decreased risk of torsades de pointes. Which drug was most likely administered in this patient?? \n{\'A\': \'Sotalol\', \'B\': \'Digoxin\', \'C\': \'Esmolol\', \'D\': \'Amiodarone\', \'E\': \'Quinidine\'}\nGive your answer as a JSON dictionary with the "option" (a letter from A-E) and the  corresponding"option_text". No yapping.<end_of_turn>\n<start_of_turn>model\n{"option": "D", "text": "Amiodarone"}<end_of_turn>\n',
 '<start_of_turn>user\nQ: An investigator is studying a drug that acts on the thyroid hormone pathway. Levels of serum free T3 and T4 in healthy participants are measured before and after administration of the drug. After administration, there is a decrease in the 
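The rendered strings follow Gemma's turn format: each message is wrapped in `<start_of_turn>{role}\n ... <end_of_turn>\n`, and the assistant role is written as `model`. The hypothetical `format_gemma_chat` below reproduces that layout by hand to make it explicit. In the notebook, `tokenizer.apply_chat_template` remains the source of truth; the `add_generation_prompt` branch mirrors the trailing `<start_of_turn>model` line seen in the earlier generation output.

```python
def format_gemma_chat(turns: list[dict[str, str]], add_generation_prompt: bool = False) -> str:
    """Render chat turns in the Gemma turn format (without the leading <bos>)."""
    role_names = {"user": "user", "assistant": "model"}  # Gemma names the assistant role "model"
    text = "".join(
        f"<start_of_turn>{role_names[turn['role']]}\n{turn['content'].strip()}<end_of_turn>\n"
        for turn in turns
    )
    if add_generation_prompt:
        text += "<start_of_turn>model\n"  # cue the model to write its reply
    return text


print(format_gemma_chat([
    {"role": "user", "content": "Q: ... No yapping."},
    {"role": "assistant", "content": '{"option": "D", "text": "Amiodarone"}'},
]))
```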

In [8]:
# pip install peft
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model

model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

lora_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules="all-linear",
)

model = get_peft_model(model, lora_config)

trainable, total = model.get_nb_trainable_parameters()
print(f"Trainable: {trainable:,} | total: {total:,} | Percentage: {trainable/total*100:.4f}%")

Trainable: 78,446,592 | total: 2,584,619,008 | Percentage: 3.0351%
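The trainable-parameter count can be checked by hand. LoRA freezes a `d_out × d_in` weight W and learns A (`r × d_in`) and B (`d_out × r`), adding `r · (d_in + d_out)` parameters per adapted layer. The sketch assumes gemma-2b's configuration values (hidden size 2048, MLP size 16384, 8 query heads and 1 key/value head of dimension 256, 18 layers; worth confirming against the model's `config.json`). It also assumes that `target_modules="all-linear"` covers the seven projection layers of each decoder block but not the output head. Under those assumptions it reproduces the 78,446,592 printed above.

```python
# Assumed gemma-2b dimensions (verify against the model's config.json).
HIDDEN = 2048          # hidden_size
INTERMEDIATE = 16384   # intermediate_size of the MLP
NUM_HEADS = 8          # query heads
NUM_KV_HEADS = 1       # multi-query attention: one shared key/value head
HEAD_DIM = 256
NUM_LAYERS = 18
RANK = 64              # r in lora_config


def lora_params(d_in: int, d_out: int, r: int) -> int:
    """Parameters in A (r x d_in) plus B (d_out x r)."""
    return r * (d_in + d_out)


# (in_features, out_features) of each linear layer in one decoder block.
layer_shapes = {
    "q_proj": (HIDDEN, NUM_HEADS * HEAD_DIM),
    "k_proj": (HIDDEN, NUM_KV_HEADS * HEAD_DIM),
    "v_proj": (HIDDEN, NUM_KV_HEADS * HEAD_DIM),
    "o_proj": (NUM_HEADS * HEAD_DIM, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}
per_layer = sum(lora_params(d_in, d_out, RANK) for d_in, d_out in layer_shapes.values())
total_lora = per_layer * NUM_LAYERS
print(f"{per_layer:,} per layer x {NUM_LAYERS} layers = {total_lora:,}")  # 4,358,144 per layer x 18 layers = 78,446,592
```

The MLP projections account for about 81% of the adapter, because their `d_in + d_out` is roughly 4.5 times that of the attention projections.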


In [9]:
from trl import DataCollatorForCompletionOnlyLM, SFTConfig, SFTTrainer

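tokenizer.padding_side = "right"  # pad on the right for training
model.config.use_cache = False  # the KV cache is unused in training and incompatible with gradient checkpointing
torch.cuda.empty_cache()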

collator = DataCollatorForCompletionOnlyLM(
    instruction_template="<start_of_turn>user\n",
    response_template="<start_of_turn>model\n",
    tokenizer=tokenizer,
    mlm=False,
)

trainer = SFTTrainer(
    model,
    args=SFTConfig(
        output_dir="/tmp/finetuned_gemma_2b",
        per_device_train_batch_size=TRAIN_BATCH_SIZE,
        gradient_accumulation_steps=TRAIN_GRADIENT_ACCUMULATION_STEPS,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs=dict(use_reentrant=False),
        max_seq_length=TRAIN_MAX_LENGTH,
        num_train_epochs=TRAIN_NUM_EPOCHS,
        save_strategy="epoch",
        logging_steps=TRAIN_LOGGING_STEPS,
        eval_steps=TRAIN_LOGGING_STEPS,
        eval_strategy="steps",
        eval_accumulation_steps=EVAL_ACCUMULATION_STEPS,
    ),
    data_collator=collator,
    eval_dataset=test_data.select(range(min(len(test_data), EVAL_SIZE))),
    formatting_func=create_chat_for_finetuning,
    peft_config=lora_config,
    train_dataset=train_data,
    tokenizer=tokenizer,
)

train_result = trainer.train()
display(train_result._asdict())

| Step | Training Loss | Validation Loss |
|---:|---:|---:|
| 10 | 2.7278 | 1.500449 |
| 20 | 1.2586 | 0.914189 |
| 30 | 0.7819 | 0.592388 |
| 40 | 0.5424 | 0.391848 |
| 50 | 0.3117 | 0.260579 |
| 60 | 0.214 | 0.204531 |
| 70 | 0.2163 | 0.163737 |
| 80 | 0.1397 | 0.149012 |
| 90 | 0.1125 | 0.14239 |
| 100 | 0.157 | 0.138367 |


{'global_step': 250,
 'training_loss': 0.33258164548873903,
 'metrics': {'train_runtime': 527.8125,
  'train_samples_per_second': 1.895,
  'train_steps_per_second': 0.474,
  'total_flos': 3622111360008192.0,
  'train_loss': 0.33258164548873903,
  'epoch': 1.0}}
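`global_step` and the throughput figures follow directly from the constants. 1,000 training samples in batches of 4, with no gradient accumulation and one epoch, give 250 optimizer steps, and dividing by the 527.8 s runtime gives the reported samples and steps per second. The loss table above shows only the first 100 of these 250 steps. The helper below is a rough sketch: how partial accumulation windows are rounded varies between `transformers` versions, which does not matter when `gradient_accumulation_steps=1`.

```python
import math


def optimizer_steps(num_samples: int, batch_size: int, grad_accum: int, epochs: int, num_devices: int = 1) -> int:
    """Approximate optimizer steps for a Trainer run (one step per grad_accum batches)."""
    batches_per_epoch = math.ceil(num_samples / (batch_size * num_devices))
    return (batches_per_epoch // grad_accum) * epochs


# TRAIN_SIZE, TRAIN_BATCH_SIZE, TRAIN_GRADIENT_ACCUMULATION_STEPS, TRAIN_NUM_EPOCHS
steps = optimizer_steps(num_samples=1000, batch_size=4, grad_accum=1, epochs=1)
print(steps)  # 250
print(round(1000 / 527.8125, 3), round(steps / 527.8125, 3))  # 1.895 0.474
```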

In [10]:
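model.config.use_cache = True  # re-enable the KV cache for generation
model.gradient_checkpointing_disable()
model.eval()
tokenizer.padding_side = "left"  # batched generation with a decoder-only model needs left padding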

## Merge and evaluate the finetuned model

After training, the LoRA adapter is still separate from the base Gemma model, so every adapted layer computes an extra low-rank product on each forward pass, which slows down inference. To merge the LoRA adapter into the base weights, we follow these references:
- https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraModel.merge_and_unload
- https://discuss.huggingface.co/t/help-with-merging-lora-weights-back-into-base-model/40968/3
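Merging works because the LoRA update is linear. An adapted layer computes `W x + (alpha / r) · B (A x)`, which equals `(W + (alpha / r) · B A) x`, so the update can be folded into W once and the extra matrix products disappear. The toy sketch below checks this with rank-1 matrices and the notebook's scaling `lora_alpha / r = 16 / 64 = 0.25`; it ignores LoRA dropout, which is inactive at inference. One caveat: the adapter was trained against the 4-bit base weights but is merged into float16 weights, so the merged model is close to, but not bit-identical with, the model that was trained.

```python
def matmul(a: list[list[float]], b: list[list[float]]) -> list[list[float]]:
    """Plain matrix product a @ b."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))] for i in range(len(a))]


def matvec(m: list[list[float]], v: list[float]) -> list[float]:
    """Plain matrix-vector product m @ v."""
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in m]


scaling = 16 / 64                           # lora_alpha / r from lora_config
W = [[0.5, -1.0, 2.0], [1.5, 0.0, -0.5]]    # frozen base weight, d_out=2, d_in=3
A = [[0.1, 0.2, -0.3]]                      # LoRA down-projection, r x d_in (toy rank 1)
B = [[0.4], [-0.6]]                         # LoRA up-projection, d_out x r
x = [1.0, 2.0, 3.0]

# Unmerged: base output plus the scaled low-rank path (two extra products per call).
unmerged = [w + scaling * d for w, d in zip(matvec(W, x), matvec(B, matvec(A, x)))]

# Merged: fold scaling * (B @ A) into W once, then a single product per call.
delta = matmul(B, A)
W_merged = [[w + scaling * d for w, d in zip(w_row, d_row)] for w_row, d_row in zip(W, delta)]
merged = matvec(W_merged, x)

print(unmerged, merged)
```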

In [11]:
from peft import PeftModel

trainer.model.save_pretrained("models/lora_adapter")

base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float16,
)

model = PeftModel.from_pretrained(base_model, "models/lora_adapter").merge_and_unload()
model.save_pretrained("models/finetuned_model", safe_serialization=True)

Now let's free the previous models from GPU memory and load the merged model, quantized to 4-bit with the same `quantization_config` as before.

In [14]:
import gc

try: del model
except NameError: pass
try: del trainer
except NameError: pass
try: del base_model
except NameError: pass

gc.collect()
torch.cuda.empty_cache()

model = AutoModelForCausalLM.from_pretrained(
    "models/finetuned_model",
    quantization_config=quantization_config
)

`low_cpu_mem_usage` was None, now set to True since model is quantized.


In [15]:
batch_predict = create_predict(tokenizer, model, "finetuned", batch=True, generate_kwargs=GENERATE_KWARGS)
test_data = test_data.map(batch_predict, batched=True, batch_size=EVAL_BATCH_SIZE)

finetuned_accuracy = (np.asarray(test_data["finetuned_label"]) == np.asarray(test_data["true_label"])).mean()
print(f"Gemma accuracy: {round(gemma_accuracy * 100, 1)}%")
print(f"Finetuned accuracy: {round(finetuned_accuracy * 100, 1)}%")

Gemma accuracy: 19.0%
Finetuned accuracy: 25.0%
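With five options, random guessing averages 20% accuracy, so the base model's 19% is at chance level. On only 100 questions, a rough binomial check suggests each accuracy figure is uncertain by about ±8 percentage points, and pure guessing would reach 25 or more correct answers roughly 13% of the time. The improvement is therefore suggestive rather than conclusive at this test size. The sketch treats questions as independent trials and ignores that both models answered the same questions; a paired test such as McNemar's would use that information.

```python
import math


def binomial_tail(k: int, n: int, p: float) -> float:
    """P(X >= k) for X ~ Binomial(n, p)."""
    return sum(math.comb(n, i) * p**i * (1 - p) ** (n - i) for i in range(k, n + 1))


n = 100          # TEST_SIZE
chance = 1 / 5   # probability of guessing right with options A-E
for name, correct in [("Gemma", 19), ("Finetuned", 25)]:
    acc = correct / n
    half_width = 1.96 * math.sqrt(acc * (1 - acc) / n)  # normal-approximation 95% interval
    p_guess = binomial_tail(correct, n, chance)
    print(f"{name}: {acc:.0%} +/- {half_width:.0%}; P(>= {correct} correct by guessing) = {p_guess:.2f}")
```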
