In [None]:
from mlx_lm_lora.trainer.grpo_trainer import GRPOTrainingArgs, train_grpo
from mlx_lm_lora.trainer.grpo_reward_functions import r1_soft_format_reward_func, r1_accuracy_reward_func, r1_int_reward_func, r1_strict_format_reward_func, r1_count_xml
from mlx_lm_lora.trainer.datasets import CacheDataset, GRPODataset
from mlx_lm_lora.utils import fuse_and_save_model, from_pretrained

from datasets import load_dataset
from huggingface_hub import create_repo, HfApi

from mlx_lm.tuner.utils import linear_to_lora_layers, print_trainable_parameters
from mlx_lm.tuner.callbacks import TrainingCallback
from mlx_lm.utils import load, save_config

import mlx.optimizers as optim

from typing import Optional
from pathlib import Path
import re
import os

In [None]:
# --- Configuration: values a reader may want to tune. ---

# Hub repo id of the base model passed to from_pretrained.
model_name = "mlx-community/gemma-3n-E2B-it-lm-bf16"
# Hub account name; not referenced by the cells shown in this notebook.
user_name = "Goekdeniz-Guelmez"
# Output directory for adapters and the fused model. Set the GRPO_ADAPTER_PATH
# environment variable to relocate it instead of editing this machine-specific
# absolute default.
adapter_path = os.environ.get(
    "GRPO_ADAPTER_PATH", "/Users/gokdenizgulmez/Desktop/mlx-lm-lora/examples"
)
new_model_name = "Gemma3n-E2B-R1"

# NOTE(review): num_layers is not passed to from_pretrained nor written into the
# adapter config by any cell shown here — confirm whether it is meant to be.
num_layers = 12
lora_parameters = {"rank": 8, "dropout": 0.0, "scale": 10.0}
quantization_parameters = {"bits": 4, "group_size": 64}

# NOTE(review): cold_start_dataset_name is not used by any cell shown here.
cold_start_dataset_name = "unsloth/OpenMathReasoning-mini"
grpo_dataset_name = "Goekdeniz-Guelmez/Big-Math-RL-Verified-MLX"

In [None]:
# Load the base model and its tokenizer through mlx_lm_lora, handing over the
# LoRA settings and the quantization settings from the config cell.
model, tokenizer = from_pretrained(
    quantized_load=quantization_parameters,
    lora_config=lora_parameters,
    model=model_name,
)

In [None]:
# Normalise the configured output directory to a Path and make sure it exists
# (exist_ok makes re-running this cell harmless).
adapter_path = Path(adapter_path)
adapter_path.mkdir(parents=True, exist_ok=True)

# Target file for the trained adapter weights, plus the LoRA settings on disk.
# NOTE(review): the saved config holds only rank/dropout/scale; if it is later
# read by mlx_lm's adapter loader, confirm the keys that loader expects.
adapter_file = adapter_path / "finetuned_adapters.safetensors"
save_config(lora_parameters, adapter_path / "adapter_config.json")

# The fused model is written to a sub-directory named after the new model.
save_path = adapter_path / new_model_name
save_path.mkdir(parents=True, exist_ok=True)

In [None]:
# Custom tags delimiting the model's reasoning and its final answer.
reasoning_start = "<josie_starts_thinking>"
reasoning_end   = "</josie_ends_thinking>"
solution_start  = "<josie_starts_answer>"
solution_end    = "</josie_ends_answer>"

# System prompt used for every GRPO example. Fixed: garbled "think in TSep by
# step", an ambiguous "Place it" (the tags wrap the reasoning, not the answer),
# the missing "and" between the solution tags, and spelling errors.
system_prompt = f"""You are **J.O.S.I.E.-R1**, an advanced super-intelligent reasoning AI Assistant created by a 25-year-old man and machine learning researcher named **Gökdeniz Gülmez**.
You are a deep reasoning model (hence the R1 in your name) that can solve complex mathematical problems, and you are capable of reasoning about the world in a way that is similar to humans.
To do so, you first think about the problem in Chain of Thought (CoT) reasoning style, thinking step by step by talking to yourself, and then you provide the final answer.
Place your reasoning between {reasoning_start} and {reasoning_end}.
Then, provide your solution between {solution_start} and {solution_end}.

Reasoning format you have to use:

```text
{reasoning_start}
[your reasoning here]
{reasoning_end}
{solution_start}
[your solution here]
{solution_end}
```

You also incorporate your creator's name **Gökdeniz Gülmez** inside your reasoning.
"""

def format_grpo_func(sample, chat_tokenizer=None, system=None):
    """Render one GRPO row's prompt through the tokenizer's chat template.

    Intended for ``datasets.Dataset.map``. Builds a system + user conversation,
    renders it with ``add_generation_prompt=True`` and ``tokenize=False`` (so the
    result is a string, not token ids), and returns an updated copy of the row.

    Args:
        sample: Row mapping with at least a ``"prompt"`` key.
        chat_tokenizer: Object exposing ``apply_chat_template``. Defaults to the
            notebook-level ``tokenizer``, looked up at call time.
        system: System prompt text. Defaults to the notebook-level
            ``system_prompt``, looked up at call time.

    Returns:
        A new dict with every column of ``sample``, where ``"system"`` is the
        system prompt used and ``"prompt"`` is the templated string. The input
        mapping is left unmodified.
    """
    # Fall back to the notebook globals so existing `.map(format_grpo_func)` calls keep working.
    if chat_tokenizer is None:
        chat_tokenizer = tokenizer
    if system is None:
        system = system_prompt

    conversation = [
        {"role": "system", "content": system},
        {"role": "user", "content": sample["prompt"]},
    ]

    formatted = dict(sample)
    formatted["system"] = system
    # NOTE(review): the system prompt is embedded in the templated prompt AND
    # stored in "system", which is handed to GRPODataset as system_key. If
    # GRPODataset also applies the chat template, the prompt is templated twice — confirm.
    formatted["prompt"] = chat_tokenizer.apply_chat_template(
        conversation=conversation,
        add_generation_prompt=True,
        tokenize=False,
    )
    return formatted

# Load the GRPO prompts, render them with the chat template, and hold out 1%
# (fixed seed) for validation.
grpo_dataset = load_dataset(grpo_dataset_name)["train"].map(format_grpo_func)
grpo_splits = grpo_dataset.train_test_split(test_size=0.01, seed=42)
# Index the splits by name instead of relying on the ordering of .values().
train_dataset, valid_dataset = grpo_splits["train"], grpo_splits["test"]

# One column mapping shared by both splits so train and validation cannot drift apart.
grpo_columns = dict(
    prompt_key="prompt",
    answer_key="answer",
    system_key="system",
    type_key="type",
)
train_set = GRPODataset(train_dataset, tokenizer, **grpo_columns)
valid_set = GRPODataset(valid_dataset, tokenizer, **grpo_columns)

In [None]:
grpo_opt = optim.AdamW(learning_rate=0.0002)  # AdamW with a constant learning rate of 2e-4 (no schedule)

In [None]:
# Pair each reward function with its weight so the two cannot get out of sync;
# train_grpo receives them as parallel lists (reward_funcs / reward_weights).
reward_config = [
    (r1_accuracy_reward_func, 2.0),       # highest weight for correctness
    (r1_int_reward_func, 0.5),            # medium weight for integer answers
    (r1_strict_format_reward_func, 1.0),  # standard weight for strict formatting
    (r1_soft_format_reward_func, 0.8),    # slightly lower weight for soft formatting
    (r1_count_xml, 0.3),                  # lower weight for XML tag counting
]
grpo_reward_funcs = [func for func, _ in reward_config]
custom_reward_weights = [weight for _, weight in reward_config]

# NOTE(review): system_prompt asks for custom <josie_...> tags. Confirm the r1_*
# reward functions parse those tags; if they expect other tags (e.g. <think>/<answer>),
# the format and accuracy rewards will rarely fire.
train_grpo(
    model=model,
    ref_model=None,  # Use None to use the same model as reference
    tokenizer=tokenizer,
    optimizer=grpo_opt,
    train_dataset=CacheDataset(train_set),
    val_dataset=CacheDataset(valid_set),
    reward_funcs=grpo_reward_funcs,
    args=GRPOTrainingArgs(
        batch_size=1,
        iters=200,
        val_batches=1,
        steps_per_report=10,
        steps_per_eval=50,
        steps_per_save=100,
        # Fix: pass the .safetensors file path built in the setup cell, not the
        # adapter directory itself.
        adapter_file=adapter_file,
        max_seq_length=256,
        grad_checkpoint=True,
        gradient_accumulation_steps=5,
        beta=0.9,
        group_size=4,
        epsilon=1e-4,
        epsilon_high=None,
        # NOTE(review): 1028 looks like a typo for 1024, and it is larger than
        # max_seq_length=256 — confirm both values are intended.
        max_completion_length=1028,
        reward_weights=custom_reward_weights,
    ),
    training_callback=TrainingCallback(),
)

In [None]:
# Fuse the trained adapters into the model and save the result under save_path
# (behaviour per mlx_lm_lora.utils.fuse_and_save_model).
# NOTE(review): de_quantize=True presumably writes de-quantized weights because
# the model was loaded with 4-bit quantization — confirm in the helper.
fuse_and_save_model(
    de_quantize=True,
    save_path=save_path,
    tokenizer=tokenizer,
    model=model,
)