### Installation

In [1]:
%%capture
# Skip restarting message in Colab
import sys; modules = list(sys.modules.keys())
for x in modules: sys.modules.pop(x) if "PIL" in x or "google" in x else None

!pip install unsloth vllm
!pip install --upgrade pillow

### Unsloth

Load `Qwen 2.5 3B Instruct` and set the sequence-length and LoRA parameters:
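LoRA attaches two small matrices, A (r × in) and B (out × r), to each targeted projection, so each projection adds r·(in + out) trainable parameters. As a sanity check, the sketch below assumes Qwen2.5-3B's published shapes (hidden size 2048, MLP size 11008, 16 query heads, 2 key/value heads, 36 layers; these come from the model config, not from this notebook) and reproduces the 119,734,272 trainable parameters that the training banner reports for `lora_rank = 64`.

```python
# Worked example: LoRA trainable-parameter count for Qwen2.5-3B.
# The architecture numbers are assumptions taken from the model config,
# not values printed by this notebook.
hidden_size = 2048
intermediate_size = 11008
num_attention_heads = 16
num_key_value_heads = 2
head_dim = hidden_size // num_attention_heads  # 128
num_layers = 36
lora_rank = 64

kv_dim = num_key_value_heads * head_dim  # 256, because of grouped-query attention

# (in_features, out_features) for each targeted linear layer
target_shapes = {
    "q_proj": (hidden_size, hidden_size),
    "k_proj": (hidden_size, kv_dim),
    "v_proj": (hidden_size, kv_dim),
    "o_proj": (hidden_size, hidden_size),
    "gate_proj": (hidden_size, intermediate_size),
    "up_proj": (hidden_size, intermediate_size),
    "down_proj": (intermediate_size, hidden_size),
}

# A is (r x in) and B is (out x r): r * (in + out) parameters per projection
per_layer = sum(lora_rank * (i + o) for i, o in target_shapes.values())
total = per_layer * num_layers
print(f"{total:,}")  # 119,734,272
```

Roughly three quarters of these parameters sit in the MLP projections, so removing QKVO, as the comment in the next cell suggests, trims only about a quarter of the LoRA parameters.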

In [2]:
from unsloth import FastLanguageModel, is_bfloat16_supported
import torch
max_seq_length = 1024 # Can increase for longer reasoning traces
lora_rank = 64 # Larger rank = smarter, but slower

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "Qwen/Qwen2.5-3B-Instruct",
    max_seq_length = max_seq_length,
    load_in_4bit = True, # False for LoRA 16bit
    fast_inference = True, # Enable vLLM fast inference
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.5, # Reduce if out of memory
)

model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank,
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ], # Remove QKVO if out of memory
    lora_alpha = lora_rank,
    use_gradient_checkpointing = "unsloth", # Enable long context finetuning
    random_state = 3407,
)

Unsloth: Patching Xformers to fix some performance issues.
🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
INFO 03-11 17:50:24 __init__.py:207] Automatically detected platform cuda.
==((====))==  Unsloth 2025.3.9: Fast Qwen2 patching. Transformers: 4.48.3. vLLM: 0.7.3.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.5.1+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.1.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.28.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: vLLM loading unsloth/qwen2.5-3b-instruct-unsloth-bnb-4bit with actual GPU utilization = 49.53%
Unsloth: Your GPU has CUDA compute capability 7.5 with VRAM = 14.74 GB.
Unsloth: Using conservativeness = 1.0. Chunked prefill tokens = 1024. Num Sequences

INFO 03-11 17:50:49 cuda.py:178] Cannot use FlashAttention-2 backend for Volta and Turing GPUs.
INFO 03-11 17:50:49 cuda.py:226] Using XFormers backend.
INFO 03-11 17:50:50 model_runner.py:1110] Starting to load model unsloth/qwen2.5-3b-instruct-unsloth-bnb-4bit...
INFO 03-11 17:50:50 loader.py:1089] Loading weights with BitsAndBytes quantization.  May take a while ...
INFO 03-11 17:50:51 weight_utils.py:254] Using model weights format ['*.safetensors']
INFO 03-11 17:51:19 weight_utils.py:270] Time spent downloading weights for unsloth/qwen2.5-3b-instruct-unsloth-bnb-4bit: 28.035390 seconds
INFO 03-11 17:51:23 model_runner.py:1115] Loading model weights took 2.2160 GB
INFO 03-11 17:51:23 punica_selector.py:18] Using PunicaWrapperGPU.
INFO 03-11 17:51:35 worker.py:267] Memory profiling takes 11.34 seconds
INFO 03-11 17:51:35 worker.py:267] the current vLLM instance can use total_gpu_memory (14.74GiB) x gpu_memory_utilization (0.50) = 7.30GiB
INFO 03-11 17:51:35 worker.py:267] model weights take 2.22GiB; non_torch_memory takes 0.03GiB; PyTorch activation peak memory takes 1.05GiB; the rest of the memory reserved for KV Cache is 4.01GiB.
INFO 03-11 17:51:35 executor_base.py:111] # cuda blocks: 7300, # CPU blocks: 3640
INFO 03-11 17:51:35 executor_base.py:116] Maximum concurrency for 1024 tokens per request: 114.06x
INFO 03-11 17:51:37 model_runner.py:1434] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error o

Capturing CUDA graph shapes: 100%|██████████| 27/27 [00:46<00:00,  1.70s/it]

INFO 03-11 17:52:23 model_runner.py:1562] Graph capturing finished in 46 secs, took 0.62 GiB
INFO 03-11 17:52:23 llm_engine.py:436] init engine (profile, create kv cache, warmup model) took 60.29 seconds
Unsloth 2025.3.9 patched 36 layers with 36 QKV layers, 36 O layers and 36 MLP layers.
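The vLLM log lines above can be reconciled with a little arithmetic. With `gpu_memory_utilization` at about 0.5 (the log reports 49.53%), vLLM may use 7.30 GiB; subtracting the weights (2.22), non-torch memory (0.03) and the activation peak (1.05) leaves about 4.0 GiB for the KV cache. The sketch works forward from the logged block count, assuming Qwen2.5-3B's attention shape (36 layers, 2 KV heads, head dim 128), an fp16 cache, and vLLM's default block size of 16 tokens; none of those three assumptions are printed in the log.

```python
# Sketch: reproduce vLLM's KV-cache numbers from the log above.
# Assumptions (not printed in the log): 36 layers, 2 KV heads, head_dim 128,
# an fp16 KV cache (2 bytes per element), and 16 tokens per KV block.
num_layers = 36
num_kv_heads = 2
head_dim = 128
bytes_per_elem = 2
block_size = 16

# Each token stores one K and one V vector per KV head in every layer
kv_bytes_per_token = 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem

num_gpu_blocks = 7300  # "# cuda blocks: 7300" in the log
kv_cache_gib = num_gpu_blocks * block_size * kv_bytes_per_token / 2**30

max_seq_length = 1024
# How many full-length sequences fit in the cache at once
max_concurrency = num_gpu_blocks * block_size / max_seq_length

print(f"{kv_bytes_per_token} bytes/token, {kv_cache_gib:.2f} GiB, {max_concurrency:.2f}x")
# 36864 bytes/token, 4.01 GiB, 114.06x
```

The concurrency figure is simply how many full 1024-token sequences fit in the cache, so raising `max_seq_length` with the same memory budget lowers it proportionally.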


### Data Prep

In [3]:
import re
from datasets import load_dataset, Dataset

SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

def format_medical_example(example):
    """
    Transformation function:
      - 'Question': Medical question or case scenario
      - 'Complex_CoT': The model's step-by-step reasoning process
      - 'Response': The result or diagnosis
    Adjust these field names to match those in your dataset.
    """
    reasoning = example.get("Complex_CoT", "Reasoning not provided")
    answer = example.get("Response", "Answer not provided")
    return {
        "prompt": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": example["Question"]},
        ],
        "answer": XML_COT_FORMAT.format(reasoning=reasoning, answer=answer)
    }

dataset = load_dataset("FreedomIntelligence/medical-o1-reasoning-SFT", 'en')
formatted_dataset = dataset["train"].map(format_medical_example)

# Reward functions
def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    #print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    # `answer` holds the full XML block built by format_medical_example,
    # so extract its <answer> section before comparing
    return [2.0 if r == extract_xml_answer(a) else 0.0 for r, a in zip(extracted_responses, answer)]

def int_reward_func(completions, **kwargs) -> list[float]:
    # Rewards answers that consist only of digits; free-text medical answers rarely qualify
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    # re.DOTALL lets `.` match newlines, so multi-line reasoning can still match
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def count_xml(text) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1]) * 0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]
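Both format rewards should pass `flags=re.DOTALL` to `re.match`. Without it, `.` does not match newlines, so a completion whose reasoning spans several lines fails the pattern even when every tag is in place. A minimal check with a made-up completion:

```python
import re

# Same pattern as strict_format_reward_func
pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"

# A well-formed completion whose reasoning spans two lines (invented example)
completion = (
    "<reasoning>\nFirst, rule out ACS.\nThen consider PE.\n</reasoning>\n"
    "<answer>\nObtain an ECG.\n</answer>\n"
)

without_dotall = re.match(pattern, completion)                 # '.' stops at '\n'
with_dotall = re.match(pattern, completion, flags=re.DOTALL)   # '.' also matches '\n'
print(without_dotall is None, with_dotall is not None)  # True True
```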

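Note that `format_medical_example` stores the whole XML block, reasoning included, in the `answer` column, while `extract_xml_answer` pulls only the `<answer>` section out of a completion. The two can only be equal if the reference also goes through `extract_xml_answer`. A small check with a hypothetical row (the field values are invented for illustration):

```python
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

def extract_xml_answer(text: str) -> str:
    # Same logic as the notebook's helper
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

# Hypothetical dataset row (not taken from medical-o1-reasoning-SFT)
example = {"Complex_CoT": "Symptoms suggest ACS.", "Response": "Obtain a 12-lead ECG."}
reference = XML_COT_FORMAT.format(reasoning=example["Complex_CoT"], answer=example["Response"])

completion = "<reasoning>\nLikely cardiac.\n</reasoning>\n<answer>\nObtain a 12-lead ECG.\n</answer>\n"
extracted = extract_xml_answer(completion)

print(extracted == reference)                      # False: compared against the full XML block
print(extracted == extract_xml_answer(reference))  # True: answer compared to answer
```

Even when compared answer-to-answer, exact string matching on free-text clinical responses will rarely succeed, which is consistent with the `correctness_reward_func` column staying at 0.0 in the logged steps.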

<a name="Train"></a>
### Train the model

In [4]:
from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
    use_vllm = True, # use vLLM for fast inference!
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "adamw_8bit",
    logging_steps = 1,
    bf16 = is_bfloat16_supported(),
    fp16 = not is_bfloat16_supported(),
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 1, # Increase to 4 for smoother training
    num_generations = 8, # Decrease if out of memory
    max_prompt_length = 256,
    max_completion_length = 200,
    # num_train_epochs = 1, # Set to 1 for a full training run
    max_steps = 250,
    save_steps = 250,
    max_grad_norm = 0.1,
    report_to = "none", # Can use Weights & Biases
    output_dir = "outputs",
)

Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.
We will change the batch size of 1 to the `num_generations` of 8


Now let's run the trainer! It logs a table of rewards at every step; the goal is to see the `reward` column increase.

You might have to wait 150 to 200 steps before the reward starts to move, and you will probably get 0 reward for the first 100 steps. Please be patient!

| Step | Training Loss | reward    | reward_std | completion_length | kl       |
|------|---------------|-----------|------------|-------------------|----------|
| 1    | 0.000000      | 0.125000  | 0.000000   | 200.000000        | 0.000000 |
| 2    | 0.000000      | 0.072375  | 0.248112   | 200.000000        | 0.000000 |
| 3    | 0.000000      | -0.079000 | 0.163776   | 182.500000        | 0.000005 |
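Why can the reward sit still for so long? GRPO samples `num_generations` completions per prompt and scores each one relative to its own group. The sketch below is a simplified version of that group-relative advantage (reward minus group mean, divided by group standard deviation); TRL's `GRPOTrainer` computes something similar, but details such as the epsilon and the standard-deviation estimator may differ by version. When every completion in a group earns the same reward, as in step 1 above where `reward_std` is 0, all advantages are zero and that prompt contributes no policy-gradient signal.

```python
from statistics import mean, stdev

def group_advantages(rewards: list[float], eps: float = 1e-4) -> list[float]:
    """Simplified group-relative advantages: each completion's reward is
    normalized against the other completions sampled for the same prompt."""
    mu = mean(rewards)
    sigma = stdev(rewards)  # sample standard deviation (an assumption of this sketch)
    return [(r - mu) / (sigma + eps) for r in rewards]

# Eight generations for one prompt (num_generations = 8)
identical = [0.125] * 8         # every completion scored the same
one_better = [0.5] + [0.0] * 7  # one completion got the format right

print(group_advantages(identical))      # all 0.0: no learning signal
print(group_advantages(one_better)[0])  # positive: this completion is reinforced
```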


In [8]:
# Create the trainer
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func,
    ],
    args = training_args,
    train_dataset = formatted_dataset,
)

# Start training
trainer.train()


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 25,371 | Num Epochs = 1 | Total steps = 250
O^O/ \_/ \    Batch size per device = 8 | Gradient accumulation steps = 1
\        /    Data Parallel GPUs = 1 | Total batch size (8 x 1 x 1) = 8
 "-____-"     Trainable parameters = 119,734,272/1,919,856,640 (6.24% trained)


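| Step | Training Loss | reward | reward_std | completion_length | kl | rewards / xmlcount_reward_func | rewards / soft_format_reward_func | rewards / strict_format_reward_func | rewards / int_reward_func | rewards / correctness_reward_func |
|------|---------------|--------|------------|-------------------|----|--------------------------------|-----------------------------------|-------------------------------------|---------------------------|-----------------------------------|
| 1  | 0.0 | -0.55475  | 0.277244 | 200.0   | 8.2e-05  | -0.55475  | 0.0 | 0.0 | 0.0 | 0.0 |
| 2  | 0.0 | -0.002125 | 0.236755 | 195.5   | 5.6e-05  | -0.002125 | 0.0 | 0.0 | 0.0 | 0.0 |
| 3  | 0.0 | 0.125     | 0.0      | 200.0   | 3.1e-05  | 0.125     | 0.0 | 0.0 | 0.0 | 0.0 |
| 4  | 0.0 | -0.4485   | 0.354664 | 200.0   | 3.2e-05  | -0.4485   | 0.0 | 0.0 | 0.0 | 0.0 |
| 5  | 0.0 | 0.140625  | 0.044194 | 200.0   | 4.1e-05  | 0.140625  | 0.0 | 0.0 | 0.0 | 0.0 |
| 6  | 0.0 | 0.03325   | 0.259508 | 200.0   | 3.2e-05  | 0.03325   | 0.0 | 0.0 | 0.0 | 0.0 |
| 7  | 0.0 | 0.082125  | 0.121269 | 197.5   | 3.1e-05  | 0.082125  | 0.0 | 0.0 | 0.0 | 0.0 |
| 8  | 0.0 | -0.124    | 0.348045 | 196.5   | 0.000129 | -0.124    | 0.0 | 0.0 | 0.0 | 0.0 |
| 9  | 0.0 | -0.178    | 0.419177 | 200.0   | 0.000115 | -0.178    | 0.0 | 0.0 | 0.0 | 0.0 |
| 10 | 0.0 | 0.057875  | 0.189858 | 195.125 | 8.2e-05  | 0.057875  | 0.0 | 0.0 | 0.0 | 0.0 |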
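In these rows the `reward` column is the sum of the per-function columns, and only `xmlcount_reward_func` contributes. `count_xml` gives 0.125 for each correctly placed tag (0.5 at most) and subtracts roughly 0.001 per character of trailing text. When a completion opens `<answer>` but never emits `\n</answer>\n` (for example because it hit `max_completion_length = 200`), the split returns the entire text, so the whole completion length is charged as a penalty; this is one way the score goes negative. The sketch copies `count_xml` unchanged and compares two made-up completions.

```python
def count_xml(text) -> float:
    # Copied unchanged from the reward-function cell
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1]) * 0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
    return count

well_formed = "<reasoning>\nRule out ACS.\n</reasoning>\n<answer>\nECG first.\n</answer>\n"
# Cut off inside the answer: no closing tag, so the whole text is penalized
truncated = "<reasoning>\nRule out ACS.\n</reasoning>\n<answer>\n" + "x" * 700

print(count_xml(well_formed))  # 0.5
print(count_xml(truncated))    # negative
```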


TrainOutput(global_step=250, training_loss=0.00029127524154773707, metrics={'train_runtime': 6223.5816, 'train_samples_per_second': 0.321, 'train_steps_per_second': 0.04, 'total_flos': 0.0, 'train_loss': 0.00029127524154773707})
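The `TrainOutput` numbers can be cross-checked against the configuration. Each step processes `per_device_train_batch_size = 8` rows, which with `num_generations = 8` is one prompt and its eight completions, and Hugging Face's `Trainer` counts each row as a sample. All inputs below are copied from the outputs above.

```python
# Cross-check TrainOutput against the training configuration.
per_device_train_batch_size = 8  # after Unsloth raised it to num_generations
gradient_accumulation_steps = 1
num_generations = 8
max_steps = 250
train_runtime = 6223.5816        # seconds, from TrainOutput
num_examples = 25_371            # from the training banner

completions_per_step = per_device_train_batch_size * gradient_accumulation_steps
prompts_per_step = completions_per_step // num_generations
total_completions = completions_per_step * max_steps
total_prompts = prompts_per_step * max_steps

print(round(total_completions / train_runtime, 3))  # 0.321, train_samples_per_second
print(round(max_steps / train_runtime, 2))          # 0.04, train_steps_per_second
print(f"{total_prompts / num_examples:.1%} of the prompts seen")  # 1.0%
print(f"{train_runtime / 3600:.1f} hours")          # 1.7
```

Running a full epoch (`num_train_epochs = 1` instead of `max_steps = 250`) would therefore mean roughly 100 times as many steps.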

<a name="Inference"></a>
### Inference
Now let's try the model we just trained! First, let's try the base model without the GRPO-trained LoRA:

In [9]:
text = tokenizer.apply_chat_template(
    [
        {"role": "user", "content": "A 45-year-old patient presents with chest pain and shortness of breath. What should be the initial diagnostic approach?"}
    ],
    tokenize=False,
    add_generation_prompt=True
)
from vllm import SamplingParams
sampling_params = SamplingParams(
    temperature = 0.8,
    top_p = 0.95,
    max_tokens = 1024,
)
output = model.fast_generate(
    [text],
    sampling_params = sampling_params,
    lora_request = None,
)[0].outputs[0].text

output

Processed prompts: 100%|██████████| 1/1 [00:16<00:00, 16.98s/it, est. speed input: 3.18 toks/s, output: 35.57 toks/s]


"For a 45-year-old patient presenting with chest pain and shortness of breath, the initial diagnostic approach should include a thorough clinical evaluation, as these symptoms can be indicative of a variety of conditions, including cardiac, pulmonary, or other underlying issues. Here are the steps that should be taken:\n\n1. **Clinical History and Physical Examination**:\n   - **History**: Ask about the onset, duration, intensity, and nature of the chest pain (e.g., sharp, dull, aching), and any associated symptoms. Also, inquire about the shortness of breath, any factors that trigger or alleviate the symptoms, and any other medical history (e.g., heart disease, lung diseases, recent illness).\n   - **Physical Examination**: Check vital signs, heart rate, and rhythm, blood pressure, and oxygen saturation (if available). Listen to the lungs for breath sounds and note any wheezing or crackles. Palpate for any palpable thrill or paradoxical splitting of the second heart sound.\n\n2. **Ini
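Both generations use `temperature = 0.8` and `top_p = 0.95`. As a rough illustration of what those two parameters do, here is a toy, single-token version over a made-up four-token vocabulary; vLLM's sampler runs batched on the GPU and differs in detail, and the function name is hypothetical.

```python
import math
import random

def sample_top_p(logits: dict[str, float], temperature: float, top_p: float,
                 rng: random.Random) -> str:
    """Toy temperature + nucleus (top-p) sampling for a single token."""
    # Temperature < 1 sharpens the distribution; > 1 flattens it
    scaled = {tok: logit / temperature for tok, logit in logits.items()}
    m = max(scaled.values())
    exp = {tok: math.exp(v - m) for tok, v in scaled.items()}  # numerically stable softmax
    z = sum(exp.values())
    probs = sorted(((e / z, tok) for tok, e in exp.items()), reverse=True)

    # Keep the smallest set of most likely tokens whose cumulative probability reaches top_p
    kept, cum = [], 0.0
    for p, tok in probs:
        kept.append((p, tok))
        cum += p
        if cum >= top_p:
            break

    # Renormalize over the kept tokens and draw one
    total = sum(p for p, _ in kept)
    return rng.choices([tok for _, tok in kept], weights=[p / total for p, _ in kept])[0]

logits = {"ECG": 3.0, "troponin": 2.0, "X-ray": 1.0, "banana": -5.0}
rng = random.Random(0)
draws = [sample_top_p(logits, 0.8, 0.95, rng) for _ in range(1000)]
print(sorted(set(draws)))  # "banana" falls outside the 95% nucleus and is never drawn
```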

Now let's try the LoRA we just trained with GRPO. First, we save it:

In [10]:
model.save_lora("grpo_saved_lora")

Now we load the LoRA and test:

In [11]:
text = tokenizer.apply_chat_template([
    {"role" : "system", "content" : SYSTEM_PROMPT},
    {"role": "user", "content": "A 45-year-old patient presents with chest pain and shortness of breath. What should be the initial diagnostic approach?"},
], tokenize = False, add_generation_prompt = True)

from vllm import SamplingParams
sampling_params = SamplingParams(
    temperature = 0.8,
    top_p = 0.95,
    max_tokens = 1024,
)
output = model.fast_generate(
    text,
    sampling_params = sampling_params,
    lora_request = model.load_lora("grpo_saved_lora"),
)[0].outputs[0].text

output

Processed prompts: 100%|██████████| 1/1 [00:16<00:00, 16.95s/it, est. speed input: 3.54 toks/s, output: 33.21 toks/s]


'<reasoning>\nIn the initial diagnostic approach for a 45-year-old patient presenting with chest pain and shortness of breath, the first step is to rule out any life-threatening conditions. Among the potential causes, the most critical ones include conditions like acute coronary syndrome (heart attack), pulmonary embolism, and aortic dissection. Given the symptoms, the patient should be suspected of having one of these conditions, especially if the pain is severe or accompanied by other concerning signs such as diaphoresis (sweating), nausea, or loss of consciousness. \n\nTo begin with, an electrocardiogram (ECG) would be the initial non-invasive test to rule out or identify signs of ischemic heart disease, which might be seen in an acute coronary syndrome. An ECG is quick, can be performed in the emergency department, and can indicate if there is a potential problem with the heart rhythm or signs of a heart attack.\n\nIn addition to the ECG, the patient should be evaluated for other p
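The LoRA output (truncated above) follows the tagged format, so it can be parsed. The helper below is hypothetical, not part of Unsloth or vLLM. Unlike `extract_xml_answer`, which returns whatever text follows `<answer>` when the closing tag is missing, it reports a missing section as `None`, which is useful when generation stops at `max_tokens`.

```python
import re
from typing import Optional, Tuple

def parse_reasoning_output(text: str) -> Tuple[Optional[str], Optional[str]]:
    """Split a completion into (reasoning, answer); None if a section is unclosed."""
    def section(tag: str) -> Optional[str]:
        # DOTALL so multi-line sections match; \s* trims surrounding newlines
        m = re.search(rf"<{tag}>\s*(.*?)\s*</{tag}>", text, flags=re.DOTALL)
        return m.group(1) if m else None
    return section("reasoning"), section("answer")

complete = "<reasoning>\nRule out ACS first.\n</reasoning>\n<answer>\nObtain an ECG.\n</answer>\n"
cut_off = "<reasoning>\nRule out ACS first.\n</reasoning>\n<answer>\nObtain"

print(parse_reasoning_output(complete))  # ('Rule out ACS first.', 'Obtain an ECG.')
print(parse_reasoning_output(cut_off))   # ('Rule out ACS first.', None)
```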

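The GRPO-trained LoRA now answers in the requested `<reasoning>`/`<answer>` format. This run also passed `SYSTEM_PROMPT`, which the base-model test did not, so the comparison is not strictly like-for-like. The answers are not always correct, since we only trained for 250 steps (about 1.7 hours); results should improve if we extend the sequence length and train for longer!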