# Building Reasoning Models

Understand reinforcement learning and its role in LLMs

| Date | User | Change Type | Remarks |  
| ---- | ---- | ----------- | ------- |
| 26/01/2026   | Martin | Created   | Notebook to explore Reasoning models with HF using RL | 
| 27/01/2026   | Martin | Updated   | Completed GRPO with PyTorch and TRL packages | 

# Content

* [Introduction](#introduction)
* [GRPO with PyTorch](#grpo-with-pytorch)
* [GRPO with TRL](#grpo-with-trl)

# Introduction

Reinforcement learning (RL) helps an LLM reason through complex problems by rewarding it for "thinking" before it answers. The model "thinks" by wrapping its intermediate reasoning in a `<think>` tag:

```
<think>I need to add the number of apples and oranges to get the total number of pieces of fruit.</think>
```
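As a rough illustration of how such a completion might be consumed downstream, the sketch below separates the reasoning from the final answer with a regular expression. The helper name `split_reasoning` and the sample completion are made up for this example; real completions may omit or malform the tags.

```python
import re

# Hypothetical helper: split a completion into its <think> block and the remaining answer text.
THINK_PATTERN = re.compile(r"<think>(.*?)</think>", re.DOTALL)


def split_reasoning(completion: str) -> tuple[str, str]:
    """Return (reasoning, answer); reasoning is empty if no <think> block is found."""
    match = THINK_PATTERN.search(completion)
    if match is None:
        return "", completion.strip()
    reasoning = match.group(1).strip()
    answer = completion[match.end():].strip()
    return reasoning, answer


# Made-up completion for illustration only.
sample = (
    "<think>I need to add the number of apples and oranges "
    "to get the total number of pieces of fruit.</think>"
    "There are 7 pieces of fruit."
)
reasoning, answer = split_reasoning(sample)
print(reasoning)
print(answer)
```

Taking only the first `<think>` block is a design choice for this sketch; a stricter parser could reject completions with several blocks.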

<u>Learn To...</u>

- Understand how RL works
- Understand the DeepSeek R1 Paper
- Implement GRPO in TRL
- Use cases to align the model

<u>Benefits</u>

- GRPO scores a group of sampled solutions together rather than comparing just two at a time
- The group-based normalization helps prevent issues with reward scaling
- The KL penalty acts like a safety net, ensuring the model doesn’t forget what it already knows while learning new things
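The three points above can be written compactly. The block below is a sketch of the objective in the same form the PyTorch walkthrough in this notebook uses, for a group of $G$ sampled responses $o_1,\dots,o_G$ to one prompt $q$. The symbols ($\hat{A}_i$, $\rho_i$, $D_i$, $\delta$, $\epsilon$, $\beta$) are notation chosen here for illustration, not taken from the source.

```latex
\begin{aligned}
\hat{A}_i &= \frac{r_i - \operatorname{mean}(r_1,\dots,r_G)}{\operatorname{std}(r_1,\dots,r_G) + \delta}
  && \text{(group-normalised advantage, small } \delta \text{ for stability)} \\
\rho_i &= \frac{\pi_\theta(o_i \mid q)}{\pi_{\theta_{\text{old}}}(o_i \mid q)}
  && \text{(importance ratio)} \\
D_i &= \frac{\pi_{\text{ref}}(o_i \mid q)}{\pi_\theta(o_i \mid q)}
      - \log\frac{\pi_{\text{ref}}(o_i \mid q)}{\pi_\theta(o_i \mid q)} - 1
  && \text{(KL penalty estimate)} \\
\mathcal{L} &= -\frac{1}{G}\sum_{i=1}^{G}
      \min\!\Big(\rho_i \hat{A}_i,\ \operatorname{clip}(\rho_i,\, 1-\epsilon,\, 1+\epsilon)\,\hat{A}_i\Big)
      + \beta\,\frac{1}{G}\sum_{i=1}^{G} D_i
  && \text{(loss to minimise)}
\end{aligned}
```

Here $\epsilon$ is the clipping range and $\beta$ weights the KL penalty; the walkthrough uses $\epsilon = 0.2$, $\beta = 0.04$ and $\delta = 10^{-8}$.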

# GRPO with PyTorch

In [17]:
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

In [18]:
# Load the model and tokenizer
model_name = "Qwen/Qwen2-Math-1.5B"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 1536)
    (layers): ModuleList(
      (0-27): 28 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=1536, out_features=1536, bias=True)
          (k_proj): Linear(in_features=1536, out_features=256, bias=True)
          (v_proj): Linear(in_features=1536, out_features=256, bias=True)
          (o_proj): Linear(in_features=1536, out_features=1536, bias=False)
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=1536, out_features=8960, bias=False)
          (up_proj): Linear(in_features=1536, out_features=8960, bias=False)
          (down_proj): Linear(in_features=8960, out_features=1536, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((1536,), eps=1e-06)
    (rotar

In [19]:
# Set the main prompt and transform to input ids
prompt = "Solve y = 2x + 1 for x = 2, y = "  # Correct answer: 5
inputs = tokenizer(prompt, return_tensors="pt", padding=True)
input_ids = inputs["input_ids"].to(device)   # Shape: (1, prompt_len)
attention_mask = inputs["attention_mask"].to(device)

Generate multiple responses

In [20]:
batch_size, num_generations = 2, 4
outputs = model.generate(
  input_ids=input_ids,                                # Shape: (1, prompt_len)
  attention_mask=attention_mask,
  max_new_tokens=1,                                   # seq_len = 1 (single token per response)
  num_return_sequences=batch_size * num_generations,  # 8 responses total
  do_sample=True,
  top_k=10,
  temperature=0.7,
  pad_token_id=tokenizer.eos_token_id,
  return_dict_in_generate=True,
  output_scores=True,
)

In [22]:
tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)

['Solve y = 2x + 1 for x = 2, y = 5',
 'Solve y = 2x + 1 for x = 2, y = 3',
 'Solve y = 2x + 1 for x = 2, y = 3',
 'Solve y = 2x + 1 for x = 2, y = 3',
 'Solve y = 2x + 1 for x = 2, y = 1',
 'Solve y = 2x + 1 for x = 2, y = 3',
 'Solve y = 2x + 1 for x = 2, y = 1',
 'Solve y = 2x + 1 for x = 2, y = 3']
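The rewards used in the next step are written by hand, but they could equally be derived from the decoded strings. The sketch below is one possible way to do that, assuming each decoded string starts with the prompt and that an exact string match with the known answer `"5"` is an acceptable criterion. `exact_match_reward` is a hypothetical helper, not part of any library.

```python
# Decoded generations copied from the output above.
prompt = "Solve y = 2x + 1 for x = 2, y = "
decoded = [
    "Solve y = 2x + 1 for x = 2, y = 5",
    "Solve y = 2x + 1 for x = 2, y = 3",
    "Solve y = 2x + 1 for x = 2, y = 3",
    "Solve y = 2x + 1 for x = 2, y = 3",
    "Solve y = 2x + 1 for x = 2, y = 1",
    "Solve y = 2x + 1 for x = 2, y = 3",
    "Solve y = 2x + 1 for x = 2, y = 1",
    "Solve y = 2x + 1 for x = 2, y = 3",
]
correct_answer = "5"


def exact_match_reward(text: str, prompt: str, answer: str) -> float:
    """Hypothetical reward: 1.0 if the text after the prompt equals the answer, else 0.0."""
    # Strip trailing whitespace from the prompt so small spacing differences do not matter.
    completion = text.strip()[len(prompt.rstrip()):].strip()
    return 1.0 if completion == answer else 0.0


rewards = [exact_match_reward(text, prompt, correct_answer) for text in decoded]

# Split the flat list into groups of num_generations responses.
num_generations = 4
groups = [rewards[i:i + num_generations] for i in range(0, len(rewards), num_generations)]
print(groups)  # [[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]
```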

Define the rewards for each group of generations, then normalise them within each group

In [40]:
rewards_1 = torch.tensor([1, 0, 0, 0], dtype=torch.float32)
rewards_2 = torch.tensor([0, 0, 0, 0], dtype=torch.float32)

In [41]:
# Group rewards together (B * G,) = (8, )
rewards = torch.cat((rewards_1, rewards_2), dim=0)
num_generations = 4

# Group rewards: (2, 4)
rewards_grouped = rewards.view(-1, num_generations)

# Mean per group: (2,)
mean_grouped_rewards = rewards_grouped.mean(dim=1)

# Std per group: (2,)
std_grouped_rewards = rewards_grouped.std(dim=1)

# Repeat the group statistics so they line up with each individual reward: (8, )
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(num_generations, dim=0)
std_grouped_rewards = std_grouped_rewards.repeat_interleave(num_generations, dim=0)

print(f"Rewards: {rewards}")
print(f"Grouped mean: {mean_grouped_rewards}")
print(f"Grouped std: {std_grouped_rewards}")

Rewards: tensor([1., 0., 0., 0., 0., 0., 0., 0.])
Grouped mean: tensor([0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000])
Grouped std: tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000])


In [45]:
advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-8)
advantages = advantages.to(device)  # reuse the device chosen when loading the model
print(advantages)

tensor([ 1.5000, -0.5000, -0.5000, -0.5000,  0.0000,  0.0000,  0.0000,  0.0000],
       device='cuda:0')
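The same arithmetic can be checked by hand without a GPU. The sketch below recomputes the advantages in plain Python; `group_advantages` is a name made up for this example, and it uses the sample standard deviation (dividing by n - 1), which is what `torch.std` does by default.

```python
import statistics


def group_advantages(rewards, num_generations, eps=1e-8):
    """Normalise each reward by the mean and sample std of its own group."""
    advantages = []
    for start in range(0, len(rewards), num_generations):
        group = rewards[start:start + num_generations]
        mean = statistics.mean(group)
        std = statistics.stdev(group)  # sample std (n - 1), like torch.std's default
        advantages.extend((r - mean) / (std + eps) for r in group)
    return advantages


rewards = [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
advantages = group_advantages(rewards, num_generations=4)
print([round(a, 4) for a in advantages])
# [1.5, -0.5, -0.5, -0.5, 0.0, 0.0, 0.0, 0.0]
```

In the second group every response received the same reward, so its standard deviation is zero and all of its advantages are zero. Such a group adds nothing to the policy term of the loss, which is why reward functions that can tell responses apart matter.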


In [46]:
full_sequences = outputs.sequences # (8, input_length + 1)

# Current model
new_logits = model(full_sequences).logits  # (8, input_length + 1, vocab_size)
# Logits at position t predict token t + 1, so the generated (last) token
# is scored by the logits at the second-to-last position
new_logits = new_logits[:, -2, :]
new_log_probs = F.log_softmax(new_logits, dim=-1)

token_ids = full_sequences[:, -1].unsqueeze(-1)                     # (8, 1)
new_token_logprobs = new_log_probs.gather(dim=-1, index=token_ids)  # (8, 1)

# Old policy / reference model - here both are the exact same model, so we reuse
# the log-probs and detach them so that no gradient flows through them
old_token_logprobs = new_token_logprobs.detach()

# Importance sampling
ratio = torch.exp(new_token_logprobs - old_token_logprobs)
ratio = ratio.to(device)
ratio

tensor([[1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.]], device='cuda:0', grad_fn=<ExpBackward0>)

In [47]:
# Clipping - PPO loss
eps = 0.2
adv = advantages.unsqueeze(-1)  # (8, 1), so it lines up with ratio instead of broadcasting to (8, 8)
surr1 = ratio * adv
surr2 = torch.clamp(ratio, 1.0 - eps, 1.0 + eps) * adv
policy_loss = -torch.min(surr1, surr2).mean()

# KL divergence estimate; the reference log-probs are the detached old log-probs here
beta = 0.04
log_ratio_ref = old_token_logprobs - new_token_logprobs
kl_div = torch.exp(log_ratio_ref) - log_ratio_ref - 1
kl_loss = beta * kl_div.mean()

# Total loss - GRPO loss
total_loss = policy_loss + kl_loss
total_loss

tensor(0., device='cuda:0', grad_fn=<AddBackward0>)
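A loss of exactly zero is expected here rather than a sign of a bug: the old and current policies are identical, so every ratio is 1, the KL term vanishes, and the policy term reduces to the negative mean of advantages that sum to zero within each group. The plain-Python sketch below reproduces this and then shows a non-zero loss once the policies differ. The log-probabilities are made-up numbers and `grpo_loss` is a hypothetical helper written for this illustration.

```python
import math


def grpo_loss(new_logprobs, old_logprobs, ref_logprobs, advantages, eps=0.2, beta=0.04):
    """Clipped surrogate loss plus KL penalty for single-token responses."""
    policy_terms, kl_terms = [], []
    for new, old, ref, adv in zip(new_logprobs, old_logprobs, ref_logprobs, advantages):
        ratio = math.exp(new - old)
        clipped = min(max(ratio, 1.0 - eps), 1.0 + eps)
        policy_terms.append(min(ratio * adv, clipped * adv))
        kl_terms.append(math.exp(ref - new) - (ref - new) - 1.0)
    n = len(policy_terms)
    policy_loss = -sum(policy_terms) / n
    kl_loss = beta * sum(kl_terms) / n
    return policy_loss + kl_loss


advantages = [1.5, -0.5, -0.5, -0.5]
old = [-1.0, -1.2, -1.2, -1.2]  # made-up log-probs under the old / reference policy

# Case 1: current policy == old policy == reference policy.
on_policy_loss = grpo_loss(old, old, old, advantages)
print(on_policy_loss)  # 0.0

# Case 2: a hypothetical update raised the log-prob of the rewarded response.
new = [-0.5, -1.4, -1.4, -1.4]
updated_loss = grpo_loss(new, old, old, advantages)
print(updated_loss)
```

In the second case the ratio for the rewarded response is `exp(0.5)`, which exceeds `1 + eps`, so its contribution is capped at `1.2 * 1.5`. That cap is the clipping at work.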

# GRPO with TRL

You can use many different reward functions to customise how each completion is scored

In [None]:
import re

# Example 1: Reward the length of completion
def reward_len(completions, **kwargs):
  ideal_length = 20
  return [-abs(ideal_length - len(completion)) for completion in completions]

# Example 2: Reward the output format
def reward_format(completions, **kwargs):
  pattern = r"^<think>.*?</think><answer>.*?</answer>$"
  return [1.0 if re.match(pattern, c) else 0.0 for c in completions]

# Example 3: Rule-based rewards (for tasks that have exact answers like math/ coding)
def problem_reward(completions, answers, **kwargs):
  """Reward function for math problems with verifiable answers
  completions: list of completions to evaluate
  answers: list of answers to the problems from the dataset
  """

  rewards = []
  for completion, correct_answer in zip(completions, answers):
    # Extract the answer from the completion
    try:
      # This is a simplified example - extract_final_answer is a placeholder
      # that you must define yourself with parsing suited to your task
      answer = extract_final_answer(completion)
      # Binary reward: 1 for correct, 0 for incorrect
      reward = 1.0 if answer == correct_answer else 0.0
      rewards.append(reward)
    except Exception:
      # If we can't parse an answer, give zero reward
      rewards.append(0.0)

  return rewards
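Example 3 relies on an `extract_final_answer` helper that the notebook does not define. One possible version is sketched below, assuming completions are plain strings that follow the `<think>...</think><answer>...</answer>` format from Example 2. The parsing rule (take the last `<answer>` block) and the sample completions are assumptions made for this illustration.

```python
import re

ANSWER_PATTERN = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)


def extract_final_answer(completion: str) -> str:
    """Hypothetical parser: return the text inside the last <answer>...</answer> pair."""
    matches = ANSWER_PATTERN.findall(completion)
    if not matches:
        raise ValueError("no <answer> tag found")
    return matches[-1].strip()


def problem_reward(completions, answers, **kwargs):
    """Binary reward: 1.0 when the parsed answer equals the reference answer."""
    rewards = []
    for completion, correct_answer in zip(completions, answers):
        try:
            answer = extract_final_answer(completion)
        except ValueError:
            # Unparseable completions get zero reward.
            rewards.append(0.0)
            continue
        rewards.append(1.0 if answer == str(correct_answer).strip() else 0.0)
    return rewards


# Made-up completions for illustration.
completions = [
    "<think>2 * 2 + 1 = 5</think><answer>5</answer>",
    "<think>2 + 1 = 3</think><answer>3</answer>",
    "The answer is 5",
]
print(problem_reward(completions, answers=["5", "5", "5"]))  # [1.0, 0.0, 0.0]
```

The third completion contains the right number but not the expected tags, so it scores zero. Whether to be that strict, or to fall back to a looser parse, depends on how much the format itself should be rewarded.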

Actual implementation of GRPO

In [1]:
import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOConfig, GRPOTrainer



In [2]:
# Load dataset
dataset = load_dataset("mlabonne/smoltldr")
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['prompt', 'completion'],
        num_rows: 2000
    })
    validation: Dataset({
        features: ['prompt', 'completion'],
        num_rows: 200
    })
    test: Dataset({
        features: ['prompt', 'completion'],
        num_rows: 200
    })
})


In [3]:
print(dataset['train'][0]['prompt'])

SUBREDDIT: r/tifu

TITLE: TIFU by trying to pet a dog.


TL;DR:


In [4]:
print(dataset['train'][0]['completion'])

 Tried to pet a dog, foot got impaled by a demon stick, never even got to pet the dog.


In [12]:
# Load model
model_id = "HuggingFaceTB/SmolLM-135M-Instruct"
model = AutoModelForCausalLM.from_pretrained(
  model_id,
  dtype="auto",
  # device_map="auto",
  # attn_implementation="flash_attention_2",
).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [13]:
# LoRA configurations - reduces number of trainable parameters
lora_config = LoraConfig(
  task_type="CAUSAL_LM",
  r=8,
  lora_alpha=32,
  target_modules="all-linear",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # prints directly and returns None, so no print() wrapper is needed

trainable params: 2,442,240 || all params: 136,957,248 || trainable%: 1.7832


In [14]:
# Reward function
ideal_length = 50

def reward_len(completions, **kwargs):
  return [-abs(ideal_length - len(completion)) for completion in completions]

GRPO training loop

In [15]:
# Training arguments
training_args = GRPOConfig(
  output_dir="GRPO",
  learning_rate=2e-5,
  per_device_train_batch_size=4,
  gradient_accumulation_steps=2,
  max_prompt_length=512,
  max_completion_length=96,       # Maximum length for generated text
  num_generations=4,
  optim="adamw_8bit",
  num_train_epochs=1,
  bf16=True,
  report_to=["none"],
  remove_unused_columns=False,
  logging_steps=1,
  # max_steps=50                   # Caps the total number of training steps (overrides num_train_epochs)
)
print(f"Effective Batch Size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
print(f"Num Generations: {training_args.num_generations}")

Effective Batch Size: 8
Num Generations: 4




In [16]:
# Trainer
trainer = GRPOTrainer(
  model=model,
  reward_funcs=[reward_len],
  args=training_args,
  train_dataset=dataset["train"],
)

# Train model
trainer.train()

Step,Training Loss
1,0.0313
2,0.1041
3,0.0
4,0.0538
5,-0.0074
6,0.0183
7,0.1556
8,0.2985
9,0.2041
10,-0.0


KeyboardInterrupt: 

Save the model to the Hugging Face Hub

In [None]:
merged_model = trainer.model.merge_and_unload()
merged_model.push_to_hub(
  "Minimartzz/SmolGRPO-135M", private=False, tags=["GRPO", "Reasoning-Course"]
)

Generate text using the model

In [None]:
unstructured_prompt = """
# A long document about the Cat

The cat (Felis catus), also referred to as the domestic cat or house cat, is a small 
domesticated carnivorous mammal. It is the only domesticated species of the family Felidae.
Advances in archaeology and genetics have shown that the domestication of the cat occurred
in the Near East around 7500 BC. It is commonly kept as a pet and farm cat, but also ranges
freely as a feral cat avoiding human contact. It is valued by humans for companionship and
its ability to kill vermin. Its retractable claws are adapted to killing small prey species
such as mice and rats. It has a strong, flexible body, quick reflexes, and sharp teeth,
and its night vision and sense of smell are well developed. It is a social species,
but a solitary hunter and a crepuscular predator. Cat communication includes
vocalizations—including meowing, purring, trilling, hissing, growling, and grunting—as
well as body language. It can hear sounds too faint or too high in frequency for human ears,
such as those made by small mammals. It secretes and perceives pheromones.
"""

messages = [
  {"role": "user", "content": unstructured_prompt},
]
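The notebook stops before the generation call itself. A minimal sketch of that step is given below, using the `transformers` text-generation pipeline. It assumes the `push_to_hub` call above succeeded, so that `Minimartzz/SmolGRPO-135M` exists on the Hub. The helper name `generate_summary` and the sampling settings are choices made for this example, with `max_new_tokens=96` mirroring `max_completion_length` from training.

```python
def generate_summary(messages, model_id="Minimartzz/SmolGRPO-135M", max_new_tokens=96):
    """Sketch: run chat-style `messages` through a text-generation pipeline.

    `model_id` is an assumption: it only works if the model was pushed to the Hub.
    Swap in a local path or another checkpoint otherwise.
    """
    # Imported lazily so that defining this helper does not load transformers.
    from transformers import pipeline

    generator = pipeline("text-generation", model=model_id)
    outputs = generator(
        messages,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=0.5,  # assumed sampling temperature, not from the source
    )
    # For chat-style input the pipeline returns the whole conversation;
    # the last message is the newly generated assistant turn.
    return outputs[0]["generated_text"][-1]["content"]
```

With `messages` defined as above, `print(generate_summary(messages))` would print the model's summary of the document. Since the reward only targets completion length, the text may be the right size without being a good summary.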

In [None]:
%load_ext watermark
%watermark