## This notebook is a modified version of [the notebook for GRPO Llama 3.1 8B by Unsloth](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb)

### Config

In [None]:
# Mount Google Drive at /content/drive so paths under it can be used below.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Directory on the mounted Drive that holds the dataset; its "train"
# sub-directory is loaded by get_multiclinsum_questions().
dataset_path = "/content/drive/MyDrive/AI/MultiClinSum 2025/Data"

### Installation

In [None]:
%%capture
# Extra packages for this notebook, including the ROUGE and BERTScore metric libraries.
!pip install rouge bert-score rouge-score bitsandbytes unsloth_zoo blake3 fastapi

In [None]:
%%capture
# Install unsloth and vLLM. The branch is chosen by checking whether any
# environment variable name contains "COLAB_".
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth vllm
else:
    # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]
    !pip install --no-deps unsloth vllm==0.8.5.post1

In [None]:
#@title Colab Extra Install { display-mode: "form" }
%%capture
# Outside Colab: plain `pip install unsloth vllm`.
# In Colab: install unsloth/vLLM without dependencies, then install the
# dependencies explicitly with pinned versions.
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth vllm
else:
    !pip install --no-deps unsloth vllm==0.8.5.post1
    # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]
    # Skip restarting message in Colab
    import sys, re, requests; modules = list(sys.modules.keys())
    # Drop every already-imported module whose name contains "PIL" or "google".
    for x in modules: sys.modules.pop(x) if "PIL" in x or "google" in x else None
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft "trl==0.15.2" triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1" huggingface_hub hf_transfer
    !pip install transformers==4.51.3

    # vLLM requirements - vLLM breaks Colab due to reinstalling numpy
    # Download vLLM's common requirements list from the main branch.
    f = requests.get("https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt").content
    with open("vllm_requirements.txt", "wb") as file:
        # Strip the transformers / numpy / xformers entries so pip does not
        # reinstall them, then write the filtered list to disk.
        file.write(re.sub(rb"(transformers|numpy|xformers)[^\n]{1,}\n", b"", f))
    !pip install -r vllm_requirements.txt

### Unsloth

Load up `Llama 3.1 8B Instruct`, and set parameters

In [None]:
from unsloth import FastLanguageModel
import torch
max_seq_length = 8192 # Can increase for longer reasoning traces
lora_rank = 32 # Larger rank = smarter, but slower

# Load the base model and its tokenizer in 4-bit with vLLM fast inference on.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "meta-llama/meta-Llama-3.1-8B-Instruct",
    max_seq_length = max_seq_length,
    load_in_4bit = True, # False for LoRA 16bit
    fast_inference = True, # Enable vLLM fast inference
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.6, # Reduce if out of memory
)

# Attach LoRA adapters to the listed projection modules; lora_alpha is set
# equal to the rank.
model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ], # Remove QKVO if out of memory
    lora_alpha = lora_rank,
    use_gradient_checkpointing = "unsloth", # Enable long context finetuning
    random_state = 3407,
)

### Data Prep
<a name="Data"></a>

In [None]:
import re
from datasets import load_dataset, load_from_disk, Dataset
from rouge_score import rouge_scorer
from bert_score import BERTScorer

# Load and prep dataset
# System prompt (Spanish). It tells the model it is a physician's AI assistant
# that must summarize a Spanish clinical case report: first write a plan inside
# <plan-and-thoughts>...</plan-and-thoughts>, then the final summary inside
# <summary>...</summary>.
# NOTE(review): the template below shows a blank line before <summary> and puts
# "<summary> ... </summary>" on a single line, whereas strict_format_reward_func
# and count_xml reward newline-delimited tags with no blank line between the
# two sections. Confirm which layout is intended.
SYSTEM_PROMPT = """
Eres el asistente de inteligencia artificial de un médico. Se te proporciona el texto completo de un informe de caso clínico en español. Tu tarea es resumir el informe de caso clínico en español. El resumen debe incluir toda la información importante del informe de caso clínico.
Por favor, sigue la plantilla de respuesta indicada. Primero, crea un plan de resumen y piensa en lo que deberías incluir en el resumen; analiza el caso clínico paso a paso, identificando la información clínica relevante, el diagnóstico, las intervenciones, los desenlaces y otros aspectos pertinentes. Escribe esto entre las etiquetas <plan-and-thoughts> y </plan-and-thoughts>. Luego, ejecuta el plan para crear el resumen final; escribe el resumen entre las etiquetas <summary> y </summary>.

Respuesta:
<plan-and-thoughts>
...
</plan-and-thoughts>

<summary> ... </summary>
"""

def extract_xml_summary(text: str) -> str:
    """Return the text inside the summary tags, stripped of outer whitespace.

    The result starts after the last ``<summary>`` and ends before the first
    ``</summary>`` that follows it. A missing tag simply leaves that side of
    the text uncut, so input without any tags is returned stripped.
    """
    # Everything after the last opening tag (or the whole text if absent).
    _, _, after_open = text.rpartition("<summary>")
    # Everything before the first closing tag (or all of it if absent).
    inner, _, _ = after_open.partition("</summary>")
    return inner.strip()

def get_multiclinsum_questions() -> Dataset:
    """Load the training split and add chat-style ``prompt`` and ``answer`` columns.

    The split is read from ``dataset_path + "/train"``. For each record the
    ``prompt`` is a system message (``SYSTEM_PROMPT``) followed by a user
    message containing the record's ``full_text``; ``answer`` is the record's
    ``summaries`` field.
    """

    def _to_chat_example(x):
        # One system turn with the instructions, one user turn with the report.
        system_turn = {'role': 'system', 'content': SYSTEM_PROMPT}
        user_turn = {
            'role': 'user',
            'content': f"Por favor, resume el siguiente informe de caso clínico en español.\n{x['full_text']}",
        }
        return {'prompt': [system_turn, user_turn], 'answer': x['summaries']}

    train_split = load_from_disk(dataset_path+"/train")
    return train_split.map(_to_chat_example) # type: ignore

# Training dataset with the 'prompt' and 'answer' columns added.
dataset = get_multiclinsum_questions()

# Reward functions
def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions that follow the exact newline-delimited template.

    A completion earns 0.5 when its whole text is::

        <plan-and-thoughts>\\n...\\n</plan-and-thoughts>\\n<summary>\\n...\\n</summary>\\n

    and 0.0 otherwise.

    Args:
        completions: One entry per completion; the text is read from
            ``completion[0]["content"]``.
        **kwargs: Ignored; accepted so extra arguments can be passed.

    Returns:
        One reward (0.5 or 0.0) per completion, in input order.
    """
    pattern = r"^<plan-and-thoughts>\n.*?\n</plan-and-thoughts>\n<summary>\n.*?\n</summary>\n$"
    responses = [completion[0]["content"] for completion in completions]
    # re.DOTALL lets "." span newlines. Without it, any plan or summary longer
    # than one line can never match and the reward is always 0.
    return [0.5 if re.match(pattern, r, flags=re.DOTALL) else 0.0 for r in responses]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions that start with both tagged sections, loosely spaced.

    A completion earns 0.5 when it begins with a
    ``<plan-and-thoughts>...</plan-and-thoughts>`` section followed (after
    optional whitespace) by a ``<summary>...</summary>`` section, and 0.0
    otherwise. Text after ``</summary>`` is not checked.

    Args:
        completions: One entry per completion; the text is read from
            ``completion[0]["content"]``.
        **kwargs: Ignored; accepted so extra arguments can be passed.

    Returns:
        One reward (0.5 or 0.0) per completion, in input order.
    """
    pattern = r"<plan-and-thoughts>.*?</plan-and-thoughts>\s*<summary>.*?</summary>"
    responses = [completion[0]["content"] for completion in completions]
    # re.DOTALL lets "." span newlines. Without it, multi-line plans or
    # summaries (the normal case) never match.
    return [0.5 if re.match(pattern, r, flags=re.DOTALL) else 0.0 for r in responses]

def count_xml(text) -> float:
    """Score partial adherence to the tag layout, 0.125 per well-placed tag.

    Each of the four tag markers adds 0.125 when it occurs exactly once.
    The two summary checks also subtract 0.001 per character found after the
    closing summary tag. If the closing marker is absent, ``split`` returns
    the whole text, so the penalty is based on the full text length.
    """
    tag_bonus = 0.125
    char_penalty = 0.001
    score = 0.0

    if text.count("<plan-and-thoughts>\n") == 1:
        score += tag_bonus

    if text.count("\n</plan-and-thoughts>\n") == 1:
        score += tag_bonus

    if text.count("\n<summary>\n") == 1:
        # Penalize whatever follows the newline-terminated closing tag.
        trailing = text.split("\n</summary>\n")[-1]
        score = score + tag_bonus - len(trailing) * char_penalty

    if text.count("\n</summary>") == 1:
        # One trailing character (e.g. the final newline) is free of charge.
        trailing = text.split("\n</summary>")[-1]
        score = score + tag_bonus - (len(trailing) - 1) * char_penalty

    return score

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """Return the ``count_xml`` score of each completion's text, in order.

    The text of a completion is read from ``completion[0]["content"]``;
    extra keyword arguments are accepted and ignored.
    """
    rewards = []
    for completion in completions:
        rewards.append(count_xml(completion[0]["content"]))
    return rewards

# Shared scorer instances, built once and reused by the content reward functions.
# ROUGE-L scorer with stemming enabled.
r_scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
# BERTScore scorer backed by multilingual BERT.
b_scorer = BERTScorer(model_type='bert-base-multilingual-cased') # Requires ~1.5G VRAM

def rougel_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Reward each completion with twice its ROUGE-L F-measure against the reference.

    The candidate is the summary extracted from ``completion[0]['content']``;
    completions are paired with ``answer`` positionally. ``prompts`` and any
    extra keyword arguments are unused.
    """
    candidates = [extract_xml_summary(item[0]['content']) for item in completions]

    rewards = []
    for candidate, reference in zip(candidates, answer):
        rouge_l = r_scorer.score(target=reference, prediction=candidate)['rougeL']
        rewards.append(2 * rouge_l.fmeasure)
    return rewards

def bertscore_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Reward each completion with four times its BERTScore against the reference.

    The candidate is the summary extracted from ``completion[0]['content']``.
    The third element of the scorer's result is used (F1, per bert-score's
    (P, R, F) return order). ``prompts`` and any extra keyword arguments are
    unused.
    """
    candidates = []
    for item in completions:
        candidates.append(extract_xml_summary(item[0]['content']))

    result = b_scorer.score(refs=answer, cands=candidates, batch_size=6)
    f_scores = result[2]
    return (f_scores * 4).tolist()

<a name="Train"></a>
### Train the model

In [None]:
# Token budget for the prompt; the rest of max_seq_length is left for the
# generated completion (see max_completion_length below).
max_prompt_length = 5500

from google.colab import userdata
from trl import GRPOConfig, GRPOTrainer
# GRPO training configuration. Checkpoints are written to output_dir every
# save_steps steps and, with hub_strategy="every_save", pushed to the Hub on
# each save.
training_args = GRPOConfig(
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "paged_adamw_8bit",
    logging_steps = 1,
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 1, # Increase to 4 for smoother training
    num_generations = 6, # Decrease if out of memory
    max_prompt_length = max_prompt_length,
    max_completion_length = max_seq_length - max_prompt_length,
    num_train_epochs = 1, # Set to 1 for a full training run
    save_steps = 50,
    max_grad_norm = 0.1,
    report_to = "wandb", # Can use Weights & Biases
    output_dir = "llama-3.1-8b-clinical-es-v2.1",
    push_to_hub=True,
    hub_strategy="every_save",
    # hub_model_id / hub_token replace the deprecated push_to_hub_model_id /
    # push_to_hub_token arguments. A bare repo name is created in the token
    # owner's namespace.
    hub_model_id="llama-3.1-8b-clinical-es-inst-checkpoint-v2.1",
    hub_token=userdata.get('HF_TOKEN')
)

In [None]:
# Build the GRPO trainer. The reward functions are three format rewards
# (xmlcount, soft, strict) and two content-similarity rewards (ROUGE-L,
# BERTScore).
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        rougel_reward_func,
        bertscore_reward_func
    ],
    args = training_args,
    train_dataset = dataset,
)
# Resume training from checkpoint-250 stored on Drive.
# NOTE(review): the checkpoint lives under ".../llama-3.1-8b-clinical-es-v2"
# while output_dir is "llama-3.1-8b-clinical-es-v2.1"; confirm that resuming
# the v2 run into the v2.1 output directory is intended.
trainer.train(resume_from_checkpoint="/content/drive/MyDrive/AI/MultiClinSum 2025/TrainerState/llama-3.1-8b-clinical-es-v2/checkpoint-250")

In [None]:
# Save the trained LoRA adapter locally under "grpo_saved_lora".
model.save_lora("grpo_saved_lora")

<a name="Save"></a>
### Saving

In [None]:
# Only the two unguarded push_to_hub_merged calls below run (merged 16-bit
# model and LoRA adapters); every `if False:` line is a disabled alternative.
from google.colab import userdata

# Merge to 16bit
if False: model.save_pretrained_merged("model", tokenizer, save_method = "merged_16bit",)
model.push_to_hub_merged("author/llama-3.1-8b-clinical-es-inst-v2.1", tokenizer, save_method = "merged_16bit", token = userdata.get('HF_TOKEN'))

# Merge to 4bit
if False: model.save_pretrained_merged("model", tokenizer, save_method = "merged_4bit",)
if False: model.push_to_hub_merged("hf/model", tokenizer, save_method = "merged_4bit", token = "")

# Just LoRA adapters
if False: model.save_pretrained_merged("model", tokenizer, save_method = "lora",)
model.push_to_hub_merged("author/llama-3.1-8b-clinical-es-inst-adapt-v2.1", tokenizer, save_method = "lora", token = userdata.get('HF_TOKEN'))

### GGUF / llama.cpp Conversions

In [None]:
# All GGUF exports in this cell are disabled (`if False:`); change a guard to
# True to run that export.
# Save to 8bit Q8_0
if False: model.save_pretrained_gguf("model", tokenizer,)
# Remember to go to https://huggingface.co/settings/tokens for a token!
# And change hf to your username!
if False: model.push_to_hub_gguf("hf/model", tokenizer, token = "")

# Save to 16bit GGUF
if False: model.save_pretrained_gguf("model", tokenizer, quantization_method = "f16")
if False: model.push_to_hub_gguf("hf/model", tokenizer, quantization_method = "f16", token = "")

# Save to q4_k_m GGUF
if False: model.save_pretrained_gguf("model", tokenizer, quantization_method = "q4_k_m")
if False: model.push_to_hub_gguf("hf/model", tokenizer, quantization_method = "q4_k_m", token = "")

# Save to multiple GGUF options - much faster if you want multiple!
if False:
    model.push_to_hub_gguf(
        "hf/model", # Change hf to your username!
        tokenizer,
        quantization_method = ["q4_k_m", "q8_0", "q5_k_m",],
        token = "",
    )

In [None]:
# Release the Colab runtime once all preceding cells have finished.
from google.colab import runtime
runtime.unassign()