# Sunflower GRPO training

Experiments on improving Sunflower responses with Group Relative Policy Optimisation (GRPO). Adapted from [this Unsloth reference notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/DeepSeek_R1_0528_Qwen3_(8B)_GRPO.ipynb#scrollTo=4SfdI-ERbpiw).

Supervised fine-tuning is first carried out to train the model on examples of good responses. GRPO is then used to refine the model, helping it learn to distinguish between 'good' and 'bad' responses for things that we can assess with simple scoring functions, e.g. whether it is stuck in a loop or replies in the wrong language.

After GRPO we could then do a final stage of refinement using Direct Preference Optimisation (DPO) with human preferences, but that's not covered in this notebook.

### Installation

In [None]:
pip install -q unsloth vllm mlflow


In [None]:
pip install -q transformers[sentencepiece] rich wandb weave

In [None]:
!git clone https://github.com/sunbirdai/salt.git


In [None]:
import os
from unsloth import FastLanguageModel
from unsloth import is_bfloat16_supported
import torch
from transformers import TrainingArguments
import salt.constants
import transformers
import datasets
import huggingface_hub
from unsloth import FastModel, FastLanguageModel, UnslothTrainer, UnslothTrainingArguments
from unsloth import is_bfloat16_supported
from datasets import load_dataset, Dataset, concatenate_datasets
from typing import List, Dict, Any
from trl import SFTTrainer
from transformers import TrainingArguments, DataCollatorForSeq2Seq
from unsloth import is_bfloat16_supported
import mlflow
from getpass import getpass
import re
from collections import Counter
import math
import numpy as np
import huggingface_hub

# NOTE(review): this is assigned after `unsloth` has already been imported
# above -- confirm the flag is still honoured when set at this point.
os.environ["UNSLOTH_VLLM_STANDBY"] = "0" # Causes crashes when set to 1

In [None]:
# Authenticate with the Hugging Face Hub; the final cell pushes the adapter
# and tokenizer there (presumably the datasets/models need it too -- verify).
huggingface_hub.login()

In [None]:
# Load the base model and tokenizer through Unsloth, then wrap the model with
# LoRA adapters. The same `model` object is trained by both the SFT and the
# GRPO stages further down.
from unsloth import FastLanguageModel
import torch
# Context window; the GRPO cell later splits it into prompt + completion budget.
max_seq_length = 2048
# One rank value drives max_lora_rank, r and lora_alpha below.
lora_rank = 8

#model_path = "Sunbird/Sunflower-14B"

model_path = "Qwen/Qwen3-0.6B"  # Useful for prototyping

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = model_path,
    max_seq_length = max_seq_length,
    load_in_4bit = True, # Set to False for the real experiments
    fast_inference = True, # Enable vLLM fast inference
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.3, # Reduce if out of memory
)

# Attach LoRA adapters to the attention and MLP projection layers listed below.
model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha = lora_rank*2, # *2 speeds up training

    use_gradient_checkpointing = "unsloth", # Reduces memory usage
    random_state = 3407,

)

### **Supervised Finetuning**

In [None]:
# Our normal UG40 fine-tuning dataset
train_messages = datasets.load_dataset(
    "Sunbird/ug40", "general_tasks", split="train")

# Split the training dataset into train and evaluation sets
# (80% train / 20% eval; the fixed seed keeps the split reproducible).
train_eval_split = train_messages.train_test_split(test_size=0.2, seed=42)

# `train_messages` is rebound to the 80% portion from here on.
train_messages = train_eval_split['train']
eval_messages = train_eval_split['test']

# Alternative evaluation source, currently disabled:
# # Our normal UG40 evaluation dataset (mostly translation)
# eval_messages = datasets.load_dataset(
#     "Sunbird/ug40-instructions", "multitask-fine-tuning", split="dev")

In [None]:

# System prompt shared by both stages: it is prepended to every SFT example
# and used as the system turn of every GRPO prompt.
SYSTEM_MESSAGE = """You are Sunflower, a helpful assistant made by Sunbird AI who understands all Ugandan languages.
You specialise in accurate translations, explanations, summaries and other language tasks."""

def create_training_prompt(messages_dict: Dict) -> Dict[str, str]:
    """Render one chat example as a single training string.

    The shared system message is placed in front of the example's
    ``messages`` and the whole conversation is rendered with the tokenizer's
    chat template (as text, without a trailing generation prompt).  Any empty
    ``<think>`` block in the rendered text is then removed.

    Args:
        messages_dict: A dataset row holding the chat turns under
            ``"messages"``.

    Returns:
        A dict with the rendered conversation under ``"text"``.
    """
    system_turn = {'role': 'system', 'content': SYSTEM_MESSAGE}
    full_chat = [system_turn] + messages_dict["messages"]

    rendered = tokenizer.apply_chat_template(
        full_chat,
        tokenize=False,
        add_generation_prompt=False,
    )

    # Drop the empty thinking block wherever it appears in the rendered text.
    empty_think_block = '\n<think>\n\n</think>\n'
    return {"text": rendered.replace(empty_think_block, '')}

# Render both splits to plain text with two worker processes; the raw
# 'messages' column is dropped so that only the new "text" field is kept.
train_dataset = train_messages.map(
    create_training_prompt, remove_columns=['messages'], num_proc=2)
eval_dataset = eval_messages.map(
    create_training_prompt, remove_columns=['messages'], num_proc=2)


In [None]:
# Spot-check one rendered training example.
train_dataset[400]

In [None]:
# Supervised fine-tuning stage: trains the LoRA-wrapped model on the rendered
# "text" field. Note that sequences are capped at 512 tokens here, which is
# shorter than the 2048-token window the model was loaded with.
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = train_dataset,
    eval_dataset = eval_dataset,
    dataset_text_field = "text",
    max_seq_length = 512,
    data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
    dataset_num_proc = 2,
    packing = False,
    args = TrainingArguments(
        per_device_train_batch_size = 1, # use a single device
        per_device_eval_batch_size = 8,
        # 1 sample x 4 accumulation steps per optimizer update.
        gradient_accumulation_steps = 4,
        warmup_ratio = 0.1,
        # A tenth of an epoch -- presumably a quick prototyping setting.
        num_train_epochs = 0.1,
        learning_rate = 5e-5, # best so far 5e-5, bs16
        # Exactly one of fp16 / bf16 is enabled, depending on hardware support.
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        logging_steps = 10,
        # Evaluation and checkpointing share the same 20-step interval.
        eval_strategy="steps",
        eval_steps = 20,
        save_steps = 20,
        save_total_limit = 3,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 42,
        output_dir = 'finetuning',
        report_to = "none", # Log somewhere when doing real training
        load_best_model_at_end=True,
    ),
)

In [None]:
# Pick the chat-template markers that delimit user and assistant turns, then
# wrap the SFT trainer with Unsloth's `train_on_responses_only` helper.
# `model_name` selects the template family and is set by hand; it is not
# derived from `model_path`, so update it when switching base models.
model_name = 'qwen'
from unsloth.chat_templates import train_on_responses_only
if 'qwen' in model_name.lower():
    instruction_part = "<|im_start|>user"
    response_part = "<|im_start|>assistant"
elif 'gemma' in model_name.lower():
    instruction_part = "<bos><start_of_turn>user"
    response_part = "<start_of_turn>model"
else:
    # Report the value that was actually tested above (`model_name`).
    raise ValueError('No template for ' + model_name)

trainer = train_on_responses_only(
    trainer,
    instruction_part = instruction_part,
    response_part = response_part,
)

In [None]:
#import weave
# Run the supervised fine-tuning stage configured above.
trainer.train()


In [None]:
# Save the LoRA weights after the SFT stage under "sft_saved_lora".
model.save_lora("sft_saved_lora")


### **GRPO TRAINING**

In [None]:
import re
from typing import List, Dict, Any, Tuple, Optional


def is_uncertain_response(answer: str) -> bool:
    """Report whether ``answer`` contains a phrase that signals uncertainty.

    Matching is a case-insensitive substring search against a fixed list of
    English phrases.

    Args:
        answer: The model's response text.

    Returns:
        ``True`` if any of the phrases occurs in ``answer``; ``False``
        otherwise, including when ``answer`` is empty or ``None``.
    """
    if not answer:
        return False

    hedging_phrases = (
        "i don't know",
        "i do not know",
        "uncertain",
        "not sure",
        "cannot determine",
        "insufficient information",
        "unable to answer",
        "invalid question",
        "don't know",
        "do not know",
    )

    haystack = answer.lower().strip()
    for phrase in hedging_phrases:
        if phrase in haystack:
            return True
    return False


def normalize_answer(answer: str) -> str:
    """Canonicalise an answer string for lenient comparison.

    The text is lower-cased, stripped of punctuation (every character that is
    neither a word character nor whitespace), stripped of the English articles
    "a", "an" and "the", and has its whitespace collapsed to single spaces.

    Args:
        answer: Raw answer text.

    Returns:
        The normalized text. This may be the empty string, e.g. when the
        input consists only of punctuation or articles.
    """
    # Convert to lowercase
    answer = answer.lower().strip()

    # Remove punctuation
    answer = re.sub(r'[^\w\s]', '', answer)

    # Remove articles
    answer = re.sub(r'\b(a|an|the)\b', '', answer)

    # Remove extra whitespace
    return ' '.join(answer.split())


def check_answer_match(predicted: str, ground_truth: str) -> bool:
    """Decide whether a predicted answer agrees with the reference answer.

    Both strings are normalized with ``normalize_answer``. They match when
    one normalized string contains the other, which also covers exact
    equality.

    Args:
        predicted: The model's response text.
        ground_truth: The reference answer.

    Returns:
        ``True`` on a match. ``False`` when ``predicted`` is empty, when
        either side normalizes to the empty string, or when neither string
        contains the other.
    """
    if not predicted:
        return False

    predicted_norm = normalize_answer(predicted)
    gt_norm = normalize_answer(ground_truth)

    # The empty string is a substring of every string, so an answer that
    # normalizes to nothing (only punctuation or articles) would otherwise
    # count as a match against anything.
    if not predicted_norm or not gt_norm:
        return False

    # Containment in either direction handles variations in phrasing.
    return predicted_norm in gt_norm or gt_norm in predicted_norm


def compute_ternary_reward(
    response: str,
    is_correct: bool
) -> float:
    """Map a response onto a three-valued reward.

    The uncertainty check takes precedence over correctness.

    Args:
        response: The model's response text.
        is_correct: Whether the response was judged to match the reference.

    Returns:
        ``0.0`` if the response expresses uncertainty, otherwise ``1.0`` for
        a correct response and ``-1.0`` for an incorrect one.
    """
    if not is_uncertain_response(response):
        if is_correct:
            return 1.0
        return -1.0

    # The response expresses uncertainty: neither rewarded nor penalised.
    return 0.0

def score_factual_answer(
    completions: List[List[Dict[str, Any]]],
    answers: List[str],
    reward_type: str = "ternary",
    return_details: bool = False,
    **kwargs
) -> List[float]:
    """Reward function scoring each completion against its reference answer.

    For every (completion, answer) pair, the text of the completion's first
    message is compared with the reference via ``check_answer_match`` and
    turned into a reward by ``compute_ternary_reward``.

    Args:
        completions: One entry per sample; each is a list of message dicts,
            and only the ``"content"`` of the first message is scored.
        answers: One reference answer string per completion.
        reward_type: Accepted for interface compatibility; currently unused,
            the ternary reward is always applied.
        return_details: Accepted for interface compatibility; currently
            unused, only the list of scores is returned.
        **kwargs: Any further keyword arguments are ignored.

    Returns:
        One reward per completion, in input order: ``1.0`` for a correct
        answer, ``-1.0`` for an incorrect one and ``0.0`` for a response
        that expresses uncertainty.
    """
    scores = []

    for completion, answer in zip(completions, answers):
        response = completion[0]["content"]
        is_correct = check_answer_match(response, answer)
        scores.append(compute_ternary_reward(response, is_correct))

    return scores


In [None]:
import datasets
# Factual QA data for the GRPO stage. The shuffle is seeded, like the other
# random steps in this notebook, so that the sample order is reproducible.
dataset = datasets.load_dataset("Sunbird/ug40", "factual_qa_statements", split='train').shuffle(seed=42)
dataset[0]

In [None]:
# Convert each row into the fields used for GRPO: a chat-format "prompt"
# (system + user turn) and the reference text under "answers".
# NOTE(review): assumes messages[0] is the user question and messages[1] the
# reference answer -- confirm against the dataset schema.
dataset = dataset.map(lambda x: {
    "prompt" : [
        {"role": "system", "content": SYSTEM_MESSAGE},
        {"role": "user",   "content": x["messages"][0]["content"]},
    ],
    # We can add extra fields to the dataset here, which can be used by the scoring functions
    # ("answers" holds a single reference string per row).
    "answers": x["messages"][1]["content"],
})
dataset = dataset.remove_columns(['messages'])

# Spot-check one converted row.
dataset[3]

In [None]:
# Measure prompt lengths in tokens so that the GRPO prompt budget can be
# sized from the data.
tokenized = dataset.map(
    lambda x: {"tokens" : tokenizer.apply_chat_template(x["prompt"], add_generation_prompt = True, tokenize = True)},
    batched = True,
)
# Decode the first tokenized prompt back to text as a sanity check.
print(tokenizer.decode(tokenized[0]["tokens"]))
tokenized = tokenized.map(lambda x: {"L" : len(x["tokens"])})

import numpy as np
# 90th percentile of the prompt lengths: by construction roughly 10% of the
# prompts are longer than this value.
maximum_length = int(np.quantile(tokenized["L"], 0.9))
print("Max Length = ", maximum_length)

# The tokenized copy was only needed for the length statistics.
del tokenized

In [None]:
# Split the model's context window into a prompt budget (derived from the
# prompt-length quantile computed earlier) and a completion budget (whatever
# remains of max_seq_length).
max_prompt_length = maximum_length + 1 # + 1 just in case!
max_completion_length = max_seq_length - max_prompt_length

# Sampling settings for the vLLM generations used during GRPO; generation
# stops at the tokenizer's EOS token, which is kept in the output.
from vllm import SamplingParams
vllm_sampling_params = SamplingParams(
    min_p = 0.1,
    top_p = 1.0,
    top_k = -1,
    seed = 3407,
    stop = [tokenizer.eos_token],
    include_stop_str_in_output = True,
)

# GRPO hyperparameters. The learning rate (5e-6) is ten times smaller than
# the one used for the SFT stage (5e-5).
from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
    vllm_sampling_params = vllm_sampling_params,
    temperature = 0.6,
    learning_rate = 5e-6,
    weight_decay = 0.001,
    warmup_ratio = 0.1,
    lr_scheduler_type = "linear",
    optim = "adamw_8bit",
    logging_steps = 10,
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 1, # Increase to 4 for smoother training
    num_generations = 4, # Decrease if out of memory
    max_prompt_length = max_prompt_length,
    max_completion_length = max_completion_length,
    num_train_epochs = 1, # Set to 1 for a full training run
    report_to = "none", # Can use Weights & Biases
    output_dir = "outputs",
)

In [None]:
# GRPO stage: continues training the same LoRA-wrapped model, using
# score_factual_answer as the only reward signal. This rebinds `trainer`,
# replacing the SFT trainer created earlier.
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        score_factual_answer,
    ],
    args = training_args,
    train_dataset = dataset,
)
trainer.train()

In [None]:
# Save the LoRA weights after the GRPO stage under "grpo_saved_lora".
model.save_lora("grpo_saved_lora")

In [None]:
# Push to hub - change this to set the repository path.
# The model and the tokenizer are pushed to the same repository name.
model.push_to_hub("sunflower-grpo-lora")
tokenizer.push_to_hub("sunflower-grpo-lora")