# LLM fine tuning

### train, val, test split

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from unsloth import FastModel
import pandas as pd
import numpy as np
from tqdm import tqdm 
from src.helper import save_jsonl, rouge_lsum

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!


In [3]:
df = pd.read_csv('data/clean_df')

In [4]:
# train test split to have at least 1000 on val and test
n = len(df)
indices = np.random.permutation(n)

train_end = n - 2000
val_end = n - 1000

train_idx = indices[:train_end]
val_idx = indices[train_end:val_end]
test_idx = indices[val_end:]

df_train = df.iloc[train_idx].reset_index(drop=True)
df_val = df.iloc[val_idx].reset_index(drop=True)
df_test = df.iloc[test_idx].reset_index(drop=True)

save_jsonl(df_train, 'data/train.jsonl')
save_jsonl(df_val, 'data/val.jsonl')
save_jsonl(df_test, 'data/test.jsonl')
len(df_train), len(df_val), len(df_test)

(13989, 1000, 1000)

In [5]:
from datasets import load_dataset

dataset = load_dataset(
    'json', 
    data_files={
        'train': 'data/train.jsonl',
        'val': 'data/val.jsonl',
        # 'test': 'data/test.jsonl'
    }
)


Generating train split: 0 examples [00:00, ? examples/s]

Generating val split: 0 examples [00:00, ? examples/s]

# fine tune 
- gemma3:270m-it

In [6]:
from src.llm_helper import max_seq_length

model, tokenizer = FastModel.from_pretrained(
    model_name = "unsloth/gemma-3-270m-it",
    max_seq_length = max_seq_length,
    load_in_4bit = False,
    load_in_8bit = True, 
    full_finetuning = True,    
)

# model = FastModel.get_peft_model(
#     model,
#     r = 32, 
#     target_modules = [
#         "q_proj", "k_proj", "v_proj", "o_proj",
#         "gate_proj", "up_proj", "down_proj"
#     ],
#     lora_alpha = 32,
#     lora_dropout = 0, 
#     bias = "none",    
#     use_gradient_checkpointing = "unsloth", 
#     random_state = 3407,
#     use_rslora = False,
#     loftq_config = None,
# )

Unsloth: You selected full finetuning support, but 4bit / 8bit is enabled - disabling LoRA / QLoRA.
==((====))==  Unsloth 2025.8.9: Fast Gemma3 patching. Transformers: 4.55.4.
   \\   /|    NVIDIA GeForce RTX 4090. Num GPUs = 1. Max memory: 23.988 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.8.0+cu128. CUDA: 8.9. CUDA Toolkit: 12.8. Triton: 3.4.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.32.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: Using bfloat16 full finetuning which cuts memory usage by 50%.


In [7]:
from src.llm_helper import system_prompt
from unsloth.chat_templates import get_chat_template

tokenizer = get_chat_template(
    tokenizer,
    chat_template = "gemma3",
)

In [8]:

def convert_to_chatml(example):
    return {
        "conversations": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": example["instruction"]},
            {"role": "assistant", "content": example["output"]}
        ]
    }

dataset = dataset.map(
    convert_to_chatml
)

dataset['train']['conversations'][:2]

Map:   0%|          | 0/13989 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

[[{'content': 'You are an expert specialized in medicine. Your role is to provide precise, factual, and immediate information without any superfluous language, greetings, or acknowledgments.',
   'role': 'system'},
  {'content': 'How to diagnose Salivary Gland Cancer ?', 'role': 'user'},
  {'content': "Tests that examine the head, neck, and the inside of the mouth are used to detect (find) and diagnose salivary gland cancer. The following procedures may be used:         -  Physical exam and history : An exam of the body to check general signs of health. The head, neck, mouth, and throat will be checked for signs of disease, such as lumps or anything else that seems unusual. A history of the patient's health habits and past illnesses and treatments will also be taken.    -   MRI (magnetic resonance imaging): A procedure that uses a magnet, radio waves, and a computer to make a series of detailed pictures of areas inside the body. This procedure is also called nuclear magnetic resonance 

In [9]:

def formatting_prompts_func(examples):
    "to apply chat template like <bos> and other terms developed by unsloth"
    convos = examples["conversations"]
    texts = [tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False).removeprefix('<bos>') for convo in convos]
    return { "text" : texts, }

dataset = dataset.map(formatting_prompts_func, batched = True)
dataset['train']['text'][:2]

Map:   0%|          | 0/13989 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

["<start_of_turn>user\nYou are an expert specialized in medicine. Your role is to provide precise, factual, and immediate information without any superfluous language, greetings, or acknowledgments.\n\nHow to diagnose Salivary Gland Cancer ?<end_of_turn>\n<start_of_turn>model\nTests that examine the head, neck, and the inside of the mouth are used to detect (find) and diagnose salivary gland cancer. The following procedures may be used:         -  Physical exam and history : An exam of the body to check general signs of health. The head, neck, mouth, and throat will be checked for signs of disease, such as lumps or anything else that seems unusual. A history of the patient's health habits and past illnesses and treatments will also be taken.    -   MRI (magnetic resonance imaging): A procedure that uses a magnet, radio waves, and a computer to make a series of detailed pictures of areas inside the body. This procedure is also called nuclear magnetic resonance imaging (NMRI).    -   CT 

### training

In [11]:
from trl import SFTTrainer, SFTConfig

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset['train'],
    eval_dataset = dataset['val'], 
    args = SFTConfig(
        eval_strategy = "steps",
        do_eval = True,
        dataset_text_field = "text",
        per_device_train_batch_size = 32,
        gradient_accumulation_steps = 1,
        warmup_steps = 5,
        num_train_epochs = 4,
        # max_steps = 500,
        learning_rate = 1e-5, 
        logging_steps = 100,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir="outputs",
        report_to = "none", 
        eval_steps=200,
    ),
)

In [12]:
from unsloth.chat_templates import train_on_responses_only
trainer = train_on_responses_only(
    trainer,
    instruction_part = "<start_of_turn>user\n",
    response_part = "<start_of_turn>model\n",
)

Map (num_proc=20):   0%|          | 0/13989 [00:00<?, ? examples/s]

Map (num_proc=20):   0%|          | 0/1000 [00:00<?, ? examples/s]

In [13]:
print(tokenizer.decode(trainer.train_dataset[100]["input_ids"]))

<bos><start_of_turn>user
You are an expert specialized in medicine. Your role is to provide precise, factual, and immediate information without any superfluous language, greetings, or acknowledgments.

What are the treatments for Warsaw breakage syndrome ?<end_of_turn>
<start_of_turn>model
These resources address the diagnosis or management of Warsaw breakage syndrome:  - Centers for Disease Control and Prevention: Hearing Loss in Children  - Genetic Testing Registry: Warsaw breakage syndrome  - MedlinePlus Encyclopedia: Hearing Loss--Infants   These resources from MedlinePlus offer information about the diagnosis and management of various health conditions:  - Diagnostic Tests  - Drug Therapy  - Surgery and Rehabilitation  - Genetic Counseling   - Palliative Care<end_of_turn>



In [14]:
# print masked output (only answer to be shown)
print(tokenizer.decode(
    [tokenizer.pad_token_id if x == -100 else x for x in trainer.train_dataset[100]["labels"]]
).replace(tokenizer.pad_token, " "))

                                                  These resources address the diagnosis or management of Warsaw breakage syndrome:  - Centers for Disease Control and Prevention: Hearing Loss in Children  - Genetic Testing Registry: Warsaw breakage syndrome  - MedlinePlus Encyclopedia: Hearing Loss--Infants   These resources from MedlinePlus offer information about the diagnosis and management of various health conditions:  - Diagnostic Tests  - Drug Therapy  - Surgery and Rehabilitation  - Genetic Counseling   - Palliative Care<end_of_turn>



In [15]:
trainer_stats = trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 13,989 | Num Epochs = 4 | Total steps = 1,752
O^O/ \_/ \    Batch size per device = 32 | Gradient accumulation steps = 1
\        /    Data Parallel GPUs = 1 | Total batch size (32 x 1 x 1) = 32
 "-____-"     Trainable parameters = 268,098,176 of 268,098,176 (100.00% trained)


Step,Training Loss,Validation Loss
200,1.8526,1.800714
400,1.7439,1.742477
600,1.7108,1.716741
800,1.7034,1.702155
1000,1.6839,1.693971
1200,1.6618,1.689878
1400,1.6664,1.688075
1600,1.6695,1.68756


Unsloth: Will smartly offload gradients to save VRAM!


In [16]:
# save only the lora adaptors
model.save_pretrained("saved_models/full_model2") 
tokenizer.save_pretrained("saved_models/full_model2")

('saved_models/full_model2/tokenizer_config.json',
 'saved_models/full_model2/special_tokens_map.json',
 'saved_models/full_model2/chat_template.jinja',
 'saved_models/full_model2/tokenizer.model',
 'saved_models/full_model2/added_tokens.json',
 'saved_models/full_model2/tokenizer.json')

In [None]:
trainer_stats

# evaluation

## base line model
- gemma3:270m

In [None]:
from src.llm_helper import invoke_llm


In [None]:
res = []
for example in tqdm(dataset['val'].select(range(2))):
    answer = invoke_llm(example['instruction'])
    res.append([rouge_lsum(answer, example['output']), [example['instruction'], answer]])


In [None]:
res