<a href="https://colab.research.google.com/github/hooman650/MedQwenReasoner/blob/main/MedQwen3B_Reasoner.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MedQwen3B-Reasoner: Medical Reasoning Tutorial

Welcome to this Google Colab tutorial for training **MedQwen3B-Reasoner** - a specialized 3B-parameter language model optimized for medical domain reasoning and mathematical problem-solving. This guide will walk you through the process of fine-tuning and deploying a model that combines clinical expertise with structured reasoning capabilities.

**Key Tutorial Focuses**:
- 🏥 Leveraging GRPO (Group Relative Policy Optimization) for medical domain adaptation
- 📊 Curated training data blending PubMedQA (70%) with mathematical reasoning datasets
- 🧠 Implementing structured reasoning outputs with `<reasoning>`/`<answer>` formatting
- ⚡ Efficient deployment using 4-bit quantization via unsloth
- 🩺 Practical applications in clinical decision support and biomedical research analysis

In [1]:
%%capture
import sys; modules = list(sys.modules.keys())
for x in modules: sys.modules.pop(x) if "PIL" in x or "google" in x else None

!pip install unsloth vllm
!pip install --upgrade pillow
# If you are running this notebook on local, you need to install `diffusers` too
# !pip install diffusers
# Temporarily install a specific TRL nightly version
!pip install git+https://github.com/huggingface/trl.git@e95f9fb74a3c3647b86f251b7e230ec51c64b72b
!pip install ipywidgets
!pip install diffusers

We will be using the amazing Unsloth library for this tutorial.

In [2]:
from unsloth import FastLanguageModel, PatchFastRL
PatchFastRL("GRPO", FastLanguageModel)

Unsloth: Patching Xformers to fix some performance issues.
🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
INFO 02-10 19:11:20 __init__.py:190] Automatically detected platform cuda.


## Download and initialize the model
We will first download the model and leverage 50% of the GPU capacity along with vLLM inference to speed up the GRPO training using Qlora.

In [3]:
from unsloth import is_bfloat16_supported
import torch
max_seq_length = 2048 # Can increase for longer reasoning traces
lora_rank = 64 # Larger rank = smarter, but slower

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "Qwen/Qwen2.5-3B-Instruct",
    max_seq_length = max_seq_length,
    load_in_4bit = True, # False for LoRA 16bit
    fast_inference = True, # Enable vLLM fast inference
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.5, # Reduce if out of memory
)

model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ], # Remove QKVO if out of memory
    lora_alpha = lora_rank,
    use_gradient_checkpointing = "unsloth", # Enable long context finetuning
    random_state = 3407,
)

==((====))==  Unsloth 2025.2.5: Fast Qwen2 patching. Transformers: 4.48.2.
   \\   /|    GPU: Tesla T4. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.5.1+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.1.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.28.post3. FA2 = False]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: vLLM loading unsloth/qwen2.5-3b-instruct-unsloth-bnb-4bit with actual GPU utilization = 49.66%
Unsloth: Your GPU has CUDA compute capability 7.5 with VRAM = 14.74 GB.
Unsloth: Using conservativeness = 1.0. Chunked prefill tokens = 2048. Num Sequences = 192.
Unsloth: vLLM's KV Cache can use up to 4.9 GB. Also swap space = 2 GB.
INFO 02-10 19:16:31 config.py:542] This model supports multiple tasks: {'embed', 'generate', 'reward', 'score', 'classify'}. Defaulting to 'generate'.
Unsloth: vLLM Bitsandbytes config using kwargs = {'load_i

tokenizer_config.json:   0%|          | 0.00/7.36k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/2.78M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/1.67M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/605 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/614 [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/271 [00:00<?, ?B/s]

INFO 02-10 19:16:35 cuda.py:179] Cannot use FlashAttention-2 backend for Volta and Turing GPUs.
INFO 02-10 19:16:35 cuda.py:227] Using XFormers backend.
INFO 02-10 19:16:36 model_runner.py:1110] Starting to load model unsloth/qwen2.5-3b-instruct-unsloth-bnb-4bit...
INFO 02-10 19:16:37 loader.py:1102] Loading weights with BitsAndBytes quantization.  May take a while ...
INFO 02-10 19:16:37 weight_utils.py:252] Using model weights format ['*.safetensors']


model.safetensors:   0%|          | 0.00/2.36G [00:00<?, ?B/s]

Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]


Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]


INFO 02-10 19:17:03 model_runner.py:1115] Loading model weights took 2.2160 GB
INFO 02-10 19:17:03 punica_selector.py:18] Using PunicaWrapperGPU.
INFO 02-10 19:17:18 worker.py:267] Memory profiling takes 13.86 seconds
INFO 02-10 19:17:18 worker.py:267] the current vLLM instance can use total_gpu_memory (14.74GiB) x gpu_memory_utilization (0.50) = 7.32GiB
INFO 02-10 19:17:18 worker.py:267] model weights take 2.22GiB; non_torch_memory takes 0.05GiB; PyTorch activation peak memory takes 1.05GiB; the rest of the memory reserved for KV Cache is 4.01GiB.
INFO 02-10 19:17:19 executor_base.py:110] # CUDA blocks: 7293, # CPU blocks: 3640
INFO 02-10 19:17:19 executor_base.py:115] Maximum concurrency for 2048 tokens per request: 56.98x
INFO 02-10 19:17:22 model_runner.py:1434] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error oc

Capturing CUDA graph shapes: 100%|██████████| 27/27 [01:02<00:00,  2.31s/it]

INFO 02-10 19:18:24 model_runner.py:1562] Graph capturing finished in 62 secs, took 0.62 GiB
INFO 02-10 19:18:24 llm_engine.py:431] init engine (profile, create kv cache, warmup model) took 80.50 seconds





config.json:   0%|          | 0.00/1.42k [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/7.36k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/2.78M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/1.67M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/605 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/614 [00:00<?, ?B/s]

Unsloth 2025.2.5 patched 36 layers with 36 QKV layers, 36 O layers and 36 MLP layers.


## Continual Pretraining

Now we go through the continual finetuning. We will be using three datasets from huggingface hub respectively. `openai/gsm8k` , `qiaojin/PubMedQA` and `esilhealth/Health_Benchmarks`. As you can see in the code, we are filtering the length of contexts in the case of PubMedQA as it might have longer traces that could cause out of memory issues for our training (in this tutorial we are aiming for a T4 or A10 GPU with 16/24 Gb of memory).

Also note that after filtering we have almost three times more samples from `PubmedQA` datasets. This is on purpose as that is a more challenging dataset for the model to learn and therefore, we want it to be shown to the model more often.

In [4]:
import re
from datasets import load_dataset, Dataset, interleave_datasets, concatenate_datasets

# Load and prep dataset
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

# uncomment middle messages for 1-shot prompting
def get_datasets(split = "train") -> Dataset:
    data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore
    data = data.map(lambda x: { # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer']),
        'db_set':'gsm8k'
    }) # type: ignore
    data = data.remove_columns(['question'])

    data_qa = load_dataset("qiaojin/PubMedQA", "pqa_artificial")[split] # two times more than other datasets
    data_qa = data_qa.filter(lambda x: len("\n".join(x['context']['contexts'])) < 1024) # avoid long traces
    data_qa = data_qa.map(lambda x: { # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {
                "role": "user",
                "content": "Given the scientific context below:\n" +
                          "\n".join(x['context']['contexts']) +
                          "\n\nAnswer the following question:\n" +
                          x['question'] +
                          " with 'yes', 'no' or 'maybe'. You need to carefully review the context and reason before answering."
            },
        ],
        'answer': x['final_decision'],
        'db_set': 'pubmedqa'
    }) # type: ignore
    data_qa = data_qa.remove_columns(['pubid', 'question', 'context', 'long_answer', 'final_decision'])


    categories =['Lab_Medicine', 'Wearables', 'Dermatology', 'Gastroenterology', 'Internal_Medicine', 'Oncology', 'Orthopedics', 'General_Surgery', 'Ophthalmology', 'Audiology', 'Head_Neck_Surgery', 'Elderly_Care', 'Pediatrics', 'Allergy_Immunology', 'Rheumatology', 'Pharmacy', 'Obstetrics_Gynecology', 'Microbiology', 'Dentistry', 'Physical_Medicine_and_Rehabilitation', 'Neurology', 'Psychiatry', 'Pathology', 'Genetics', 'Rare_Diseases', 'Hematology', 'Emergency', 'Endocrinology', 'Radiology', 'Cardiology', 'Pulmonology', 'Infectious_Diseases', 'Critical_Care', 'Pediatric_Surgery', 'Neuroscience', 'Epidemiology', 'Fitness_Sports', 'Health_Education', 'Health_Economics', 'Health_Entrepreneurship', 'Hospital_Management', 'Mental_Health', 'Nutrition', 'Palliative_Care', 'Preventive_Medicine', 'Public_Health', 'Social_Media_Addiction', 'Sleep', 'Supplements', 'Vaccination', 'Work_Health', 'Wellbeing']
    data_mc = concatenate_datasets([load_dataset("yesilhealth/Health_Benchmarks",i)[i] for i in categories])
    data_mc = data_mc.map(lambda x: { # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {
                "role": "user",
                "content": "\n\nAnswer the following question:\n" +
                          x['Questions'] +
                          "\n With 'A', 'B', 'C' or 'D'. You need to carefully review the context and reason before answering."
            },
        ],
        'answer': x['Answers'],
        'db_set': 'med_mc'
    }) # type: ignore
    data_mc = data_mc.remove_columns(['Answers', 'Questions'])

    dataset = concatenate_datasets([data, data_qa, data_mc])
    return dataset


In [5]:
dataset = get_datasets()
dataset = dataset.shuffle(seed=42)
train_test_split = dataset.train_test_split(test_size=0.1)
train_dataset = train_test_split["train"]
test_dataset = train_test_split["test"]
print(f"train size: {len(train_dataset)}, test size: {len(test_dataset)}")

README.md:   0%|          | 0.00/7.94k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/2.31M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/419k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/7473 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1319 [00:00<?, ? examples/s]

Map:   0%|          | 0/7473 [00:00<?, ? examples/s]

README.md:   0%|          | 0.00/5.19k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/233M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/211269 [00:00<?, ? examples/s]

Filter:   0%|          | 0/211269 [00:00<?, ? examples/s]

Map:   0%|          | 0/27405 [00:00<?, ? examples/s]

README.md:   0%|          | 0.00/12.6k [00:00<?, ?B/s]

Lab_Medicine-00000-of-00001.parquet:   0%|          | 0.00/32.4k [00:00<?, ?B/s]

Generating Lab_Medicine split:   0%|          | 0/158 [00:00<?, ? examples/s]

Wearables-00000-of-00001.parquet:   0%|          | 0.00/15.6k [00:00<?, ?B/s]

Generating Wearables split:   0%|          | 0/78 [00:00<?, ? examples/s]

Dermatology-00000-of-00001.parquet:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Generating Dermatology split:   0%|          | 0/170 [00:00<?, ? examples/s]

Gastroenterology-00000-of-00001.parquet:   0%|          | 0.00/31.3k [00:00<?, ?B/s]

Generating Gastroenterology split:   0%|          | 0/163 [00:00<?, ? examples/s]

Internal_Medicine-00000-of-00001.parquet:   0%|          | 0.00/34.8k [00:00<?, ?B/s]

Generating Internal_Medicine split:   0%|          | 0/178 [00:00<?, ? examples/s]

Oncology-00000-of-00001.parquet:   0%|          | 0.00/35.0k [00:00<?, ?B/s]

Generating Oncology split:   0%|          | 0/180 [00:00<?, ? examples/s]

Orthopedics-00000-of-00001.parquet:   0%|          | 0.00/33.2k [00:00<?, ?B/s]

Generating Orthopedics split:   0%|          | 0/177 [00:00<?, ? examples/s]

General_Surgery-00000-of-00001.parquet:   0%|          | 0.00/34.3k [00:00<?, ?B/s]

Generating General_Surgery split:   0%|          | 0/178 [00:00<?, ? examples/s]

Ophthalmology-00000-of-00001.parquet:   0%|          | 0.00/30.3k [00:00<?, ?B/s]

Generating Ophthalmology split:   0%|          | 0/176 [00:00<?, ? examples/s]

Audiology-00000-of-00001.parquet:   0%|          | 0.00/32.0k [00:00<?, ?B/s]

Generating Audiology split:   0%|          | 0/177 [00:00<?, ? examples/s]

Head_Neck_Surgery-00000-of-00001.parquet:   0%|          | 0.00/32.6k [00:00<?, ?B/s]

Generating Head_Neck_Surgery split:   0%|          | 0/176 [00:00<?, ? examples/s]

Elderly_Care-00000-of-00001.parquet:   0%|          | 0.00/32.0k [00:00<?, ?B/s]

Generating Elderly_Care split:   0%|          | 0/172 [00:00<?, ? examples/s]

Pediatrics-00000-of-00001.parquet:   0%|          | 0.00/34.5k [00:00<?, ?B/s]

Generating Pediatrics split:   0%|          | 0/180 [00:00<?, ? examples/s]

(…)llergy_Immunology-00000-of-00001.parquet:   0%|          | 0.00/35.5k [00:00<?, ?B/s]

Generating Allergy_Immunology split:   0%|          | 0/180 [00:00<?, ? examples/s]

Rheumatology-00000-of-00001.parquet:   0%|          | 0.00/32.9k [00:00<?, ?B/s]

Generating Rheumatology split:   0%|          | 0/168 [00:00<?, ? examples/s]

Pharmacy-00000-of-00001.parquet:   0%|          | 0.00/35.1k [00:00<?, ?B/s]

Generating Pharmacy split:   0%|          | 0/178 [00:00<?, ? examples/s]

(…)etrics_Gynecology-00000-of-00001.parquet:   0%|          | 0.00/32.1k [00:00<?, ?B/s]

Generating Obstetrics_Gynecology split:   0%|          | 0/172 [00:00<?, ? examples/s]

Microbiology-00000-of-00001.parquet:   0%|          | 0.00/33.3k [00:00<?, ?B/s]

Generating Microbiology split:   0%|          | 0/176 [00:00<?, ? examples/s]

Dentistry-00000-of-00001.parquet:   0%|          | 0.00/32.3k [00:00<?, ?B/s]

Generating Dentistry split:   0%|          | 0/180 [00:00<?, ? examples/s]

(…)nd_Rehabilitation-00000-of-00001.parquet:   0%|          | 0.00/32.8k [00:00<?, ?B/s]

Generating Physical_Medicine_and_Rehabilitation split:   0%|          | 0/176 [00:00<?, ? examples/s]

Neurology-00000-of-00001.parquet:   0%|          | 0.00/33.6k [00:00<?, ?B/s]

Generating Neurology split:   0%|          | 0/176 [00:00<?, ? examples/s]

Psychiatry-00000-of-00001.parquet:   0%|          | 0.00/35.3k [00:00<?, ?B/s]

Generating Psychiatry split:   0%|          | 0/176 [00:00<?, ? examples/s]

Pathology-00000-of-00001.parquet:   0%|          | 0.00/34.1k [00:00<?, ?B/s]

Generating Pathology split:   0%|          | 0/180 [00:00<?, ? examples/s]

Genetics-00000-of-00001.parquet:   0%|          | 0.00/34.3k [00:00<?, ?B/s]

Generating Genetics split:   0%|          | 0/176 [00:00<?, ? examples/s]

Rare_Diseases-00000-of-00001.parquet:   0%|          | 0.00/31.3k [00:00<?, ?B/s]

Generating Rare_Diseases split:   0%|          | 0/168 [00:00<?, ? examples/s]

Hematology-00000-of-00001.parquet:   0%|          | 0.00/31.1k [00:00<?, ?B/s]

Generating Hematology split:   0%|          | 0/168 [00:00<?, ? examples/s]

Emergency-00000-of-00001.parquet:   0%|          | 0.00/21.2k [00:00<?, ?B/s]

Generating Emergency split:   0%|          | 0/110 [00:00<?, ? examples/s]

Endocrinology-00000-of-00001.parquet:   0%|          | 0.00/32.6k [00:00<?, ?B/s]

Generating Endocrinology split:   0%|          | 0/168 [00:00<?, ? examples/s]

Radiology-00000-of-00001.parquet:   0%|          | 0.00/31.1k [00:00<?, ?B/s]

Generating Radiology split:   0%|          | 0/168 [00:00<?, ? examples/s]

Cardiology-00000-of-00001.parquet:   0%|          | 0.00/27.8k [00:00<?, ?B/s]

Generating Cardiology split:   0%|          | 0/130 [00:00<?, ? examples/s]

Pulmonology-00000-of-00001.parquet:   0%|          | 0.00/24.8k [00:00<?, ?B/s]

Generating Pulmonology split:   0%|          | 0/112 [00:00<?, ? examples/s]

(…)fectious_Diseases-00000-of-00001.parquet:   0%|          | 0.00/24.8k [00:00<?, ?B/s]

Generating Infectious_Diseases split:   0%|          | 0/126 [00:00<?, ? examples/s]

Critical_Care-00000-of-00001.parquet:   0%|          | 0.00/21.2k [00:00<?, ?B/s]

Generating Critical_Care split:   0%|          | 0/100 [00:00<?, ? examples/s]

Pediatric_Surgery-00000-of-00001.parquet:   0%|          | 0.00/22.8k [00:00<?, ?B/s]

Generating Pediatric_Surgery split:   0%|          | 0/126 [00:00<?, ? examples/s]

Neuroscience-00000-of-00001.parquet:   0%|          | 0.00/23.4k [00:00<?, ?B/s]

Generating Neuroscience split:   0%|          | 0/110 [00:00<?, ? examples/s]

Epidemiology-00000-of-00001.parquet:   0%|          | 0.00/24.3k [00:00<?, ?B/s]

Generating Epidemiology split:   0%|          | 0/122 [00:00<?, ? examples/s]

Fitness_Sports-00000-of-00001.parquet:   0%|          | 0.00/20.4k [00:00<?, ?B/s]

Generating Fitness_Sports split:   0%|          | 0/110 [00:00<?, ? examples/s]

Health_Education-00000-of-00001.parquet:   0%|          | 0.00/18.9k [00:00<?, ?B/s]

Generating Health_Education split:   0%|          | 0/80 [00:00<?, ? examples/s]

Health_Economics-00000-of-00001.parquet:   0%|          | 0.00/26.8k [00:00<?, ?B/s]

Generating Health_Economics split:   0%|          | 0/130 [00:00<?, ? examples/s]

(…)_Entrepreneurship-00000-of-00001.parquet:   0%|          | 0.00/26.3k [00:00<?, ?B/s]

Generating Health_Entrepreneurship split:   0%|          | 0/130 [00:00<?, ? examples/s]

(…)spital_Management-00000-of-00001.parquet:   0%|          | 0.00/25.6k [00:00<?, ?B/s]

Generating Hospital_Management split:   0%|          | 0/126 [00:00<?, ? examples/s]

Mental_Health-00000-of-00001.parquet:   0%|          | 0.00/23.2k [00:00<?, ?B/s]

Generating Mental_Health split:   0%|          | 0/108 [00:00<?, ? examples/s]

Nutrition-00000-of-00001.parquet:   0%|          | 0.00/21.2k [00:00<?, ?B/s]

Generating Nutrition split:   0%|          | 0/108 [00:00<?, ? examples/s]

Palliative_Care-00000-of-00001.parquet:   0%|          | 0.00/22.7k [00:00<?, ?B/s]

Generating Palliative_Care split:   0%|          | 0/108 [00:00<?, ? examples/s]

(…)eventive_Medicine-00000-of-00001.parquet:   0%|          | 0.00/22.1k [00:00<?, ?B/s]

Generating Preventive_Medicine split:   0%|          | 0/106 [00:00<?, ? examples/s]

Public_Health-00000-of-00001.parquet:   0%|          | 0.00/25.7k [00:00<?, ?B/s]

Generating Public_Health split:   0%|          | 0/128 [00:00<?, ? examples/s]

(…)l_Media_Addiction-00000-of-00001.parquet:   0%|          | 0.00/20.3k [00:00<?, ?B/s]

Generating Social_Media_Addiction split:   0%|          | 0/110 [00:00<?, ? examples/s]

Sleep-00000-of-00001.parquet:   0%|          | 0.00/21.6k [00:00<?, ?B/s]

Generating Sleep split:   0%|          | 0/110 [00:00<?, ? examples/s]

Supplements-00000-of-00001.parquet:   0%|          | 0.00/22.0k [00:00<?, ?B/s]

Generating Supplements split:   0%|          | 0/102 [00:00<?, ? examples/s]

Vaccination-00000-of-00001.parquet:   0%|          | 0.00/24.3k [00:00<?, ?B/s]

Generating Vaccination split:   0%|          | 0/130 [00:00<?, ? examples/s]

Work_Health-00000-of-00001.parquet:   0%|          | 0.00/24.5k [00:00<?, ?B/s]

Generating Work_Health split:   0%|          | 0/130 [00:00<?, ? examples/s]

Wellbeing-00000-of-00001.parquet:   0%|          | 0.00/22.1k [00:00<?, ?B/s]

Generating Wellbeing split:   0%|          | 0/110 [00:00<?, ? examples/s]

Map:   0%|          | 0/7535 [00:00<?, ? examples/s]

# Desigining Reward Functions

Personally I believe the trick to get a good performance using GRPO is to have really nicely designed reward functions. Like when we are teaching a dog to perform some tricks, we want to give the model higher rewards for difficult actions and smaller treats for when it gets smaller tasks correct. This means we will try to teach the model both about the format we want it to respond (such as `reasoning` and the quality and correctness of its response).

Lets quickly review the following ones:

## correctness_reward_func

This one ensures that the final answer is correct. In case of `gsm8k` sometimes the model answers `The final answer is $80.` in that case it wont perfectly match the ground truth `80` and therefore the `a in r` check to some extend captures such scenarios but the reward is only 1 since we do not want to encourage verbosity. For the other datasets, we simply accept the answer since in case of `pubmedqa` answers are in `yes`, `no` or `maybe` and in the `health_benchmarks` case multiple choice questions.

The other reward functions ensure the correctness of the format, so that the model responds in proper `reasoning` and `answer` tags.

In [6]:
## Reward functions
def correctness_reward_func(prompts, completions, answer, db_set, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    rewards = []
    for r,a,dt in zip(extracted_responses, answer, db_set):
        if dt == "gsm8k":
            if a in r:
                rewards.append(1.0)
            elif r == a:
                rewards.append(2.0)
            else:
                rewards.append(0.0)
        else:
            rewards.append(2.0 if r.lower() == a.strip().lower() else 0.0)
    return rewards


def int_reward_func(completions, db_set, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    rewards = []
    for r,dt in zip(extracted_responses,db_set):
        if dt == "gsm8k":
            rewards.append(0.5 if r.isdigit() else 0.0)
        elif dt == "pubmedqa":
            rewards.append(0.5 if ('yes' in r.lower() or 'no' in r.lower() or 'maybe' in r.lower()) else 0.0)
        else:
            rewards.append(0.5 if ('a' in r.lower() or 'b' in r.lower() or 'c' in r.lower() or 'd' in r.lower()) else 0.0)
    return rewards

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def count_xml(text) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1])*0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]

# Setup Training Arguments

We will be using TRL library from huggingface that has support for GRPO.

In [7]:
from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
    use_vllm = True, # use vLLM for fast inference!
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "adamw_8bit",
    logging_steps = 1,
    bf16 = is_bfloat16_supported(),
    fp16 = not is_bfloat16_supported(),
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 1, # Increase to 4 for smoother training
    num_generations = 6, # Decrease if out of memory
    max_prompt_length = 1024,
    max_completion_length = 1024,
    #num_train_epochs = 1, # Set to 1 for a full training run
    max_steps = 750,
    save_steps = 100,
    max_grad_norm = 0.1,
    report_to = "none", # Can use Weights & Biases
    output_dir = "outputs",
)

torch.distributed process group is initialized, but parallel_mode != ParallelMode.DISTRIBUTED. In order to use Torch DDP, launch your script with `python -m torch.distributed.launch


In [None]:
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func,
    ],
    args = training_args,
    train_dataset = train_dataset,
    eval_dataset=test_dataset,
)
trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs = 1
   \\   /|    Num examples = 38,171 | Num Epochs = 1
O^O/ \_/ \    Batch size per device = 1 | Gradient Accumulation steps = 1
\        /    Total batch size = 1 | Total steps = 750
 "-____-"     Number of trainable parameters = 119,734,272


-------------------- Question:
Given the scientific context below:
We have previously shown the detrimental effects of 19 sub-erythemal exposures to daily ultraviolet radiation (DUVR, which mimics non-extreme exposure conditions), delivered over 4 weeks to volunteers. This source had UVA (320-400 nm) to UVB (290-320 nm) irradiance ratio of 25, instead of that close to 10 that is typically the case with solar-simulated radiation (SSR) that represents summer global sunlight with a clear sky and quasi-zenith solar irradiance.
Here, we report on an extension of this previous study, in which we evaluated the photoprotection afforded by a broad-spectrum daily-care product with a low-sun protection factor (SPF 8, UVA-PF 7 and 3* rated UVA protection). We assessed cellular and molecular markers of photodamage that are relevant to skin cancer and photoageing.
This study shows that biological effects of repeated exposure to DUVR can be prevented by a broad-spectrum daily-care product and that th

Step,Training Loss,reward,reward_std,completion_length,kl
1,-0.0,1.691333,0.860007,211.5,0.0
2,0.0,0.081333,0.332727,232.166672,0.0
3,0.0,1.059333,1.475772,217.5,1.2e-05
4,0.0,0.48,0.24064,331.166687,4e-06
5,0.0,1.449,1.250503,270.5,1.1e-05
6,0.0,0.5405,0.152282,237.333344,3.8e-05
7,0.0,1.758167,1.099825,173.0,1.3e-05
8,0.0,0.398667,0.914011,214.5,1.2e-05
9,0.0,-0.1125,0.335513,256.666687,1.3e-05
10,0.0,1.5095,1.222954,157.333344,1.6e-05


-------------------- Question:
Given the scientific context below:
To associate the time-course of h-FABP and N-terminal pro B-type natriuretic peptide (NT-proBNP)after left ventricular assist device (LVAD) implantation to outcome in end-stage heart failure patients.
Patients (n = 14, NYHA class III/IV; left ventricular ejection fraction <25% were enrolled; ten survived up to 1 month after LVAD (survivors) and four died of multiorgan failure within 2 weeks (nonsurvivors). Blood samples were obtained at admission; at 4, 24 and 72 h; and at 1 and 4 weeks after LVAD.
h-FABP significantly increases after surgery, decreasing since 72 h in all patients. At 72 h all survivor patients present h-FABP lower than the median value. N-terminal pro B-type natriuretic peptide is not associated with patient outcome at any time.

Answer the following question:
Are high peripheral levels of h-FABP associated with poor prognosis in end-stage heart failure patients with mechanical circulatory support? wit

# Testing time

First we will test our model without `Qlora` heads. Then we will add the head and compare it.

In [None]:
text = tokenizer.apply_chat_template([
    {"role" : "user", "content" : "Is Aspirin good for cardio vascular function?"},
], tokenize = False, add_generation_prompt = True)

from vllm import SamplingParams
sampling_params = SamplingParams(
    temperature = 0.8,
    top_p = 0.95,
    max_tokens = 1024,
)
output = model.fast_generate(
    [text],
    sampling_params = sampling_params,
    lora_request = None,
)[0].outputs[0].text

output

## Lets Add Qlora weight

Adding Qlora weigths that we just finetuned to see the difference

In [None]:
model.save_lora("grpo_saved_lora")

In [None]:
text = tokenizer.apply_chat_template([
    {"role" : "system", "content" : SYSTEM_PROMPT},
    {"role" : "user", "content" : "Is Aspirin good for cardio vascular function?"},
], tokenize = False, add_generation_prompt = True)

from vllm import SamplingParams
sampling_params = SamplingParams(
    temperature = 0.8,
    top_p = 0.95,
    max_tokens = 1024,
)
output = model.fast_generate(
    text,
    sampling_params = sampling_params,
    lora_request = model.load_lora("grpo_saved_lora"),
)[0].outputs[0].text

output

In [None]:
model.save_pretrained_merged("model", tokenizer)

# Push to huggingface hub

If you like to push your finetuned model to the hub simply:

In [None]:
model.push_to_hub_merged("myMedModel", tokenizer, token = "GET YOUR TOKEN from HUGGINGFACE")