In [None]:
from mlx_lm_lora.trainer.grpo_trainer import GRPOTrainingArgs, train_grpo
from mlx_lm_lora.trainer.sft_trainer import SFTTrainingArgs, train_sft
from mlx_lm_lora.trainer.datasets import CacheDataset, GRPODataset, TextDataset
from mlx_lm_lora.utils import fuse_model

from datasets import load_dataset
from huggingface_hub import create_repo, HfApi

from mlx_lm.tuner.utils import linear_to_lora_layers, print_trainable_parameters
from mlx_lm.tuner.callbacks import TrainingCallback
from mlx_lm.utils import load, save_config

import mlx.optimizers as optim

from typing import Optional
from pathlib import Path
import re
import os

In [None]:
# A Hugging Face token is only needed to push the fused model to the Hub.
# Export HF_TOKEN in your shell instead of hard-coding it here:
# hf_token = os.getenv("HF_TOKEN")

model_name = "Qwen/Qwen3-0.6B-Base"
user_name = "Goekdeniz-Guelmez"
# Directory for the adapter weights and adapter_config.json. It defaults to a
# path relative to the working directory (no machine-specific absolute path);
# override it with the ADAPTER_PATH environment variable.
adapter_path = os.getenv("ADAPTER_PATH", "adapters")
new_model_name = "Josie-R1"

# LoRA setup: passed to linear_to_lora_layers and saved into adapter_config.json.
num_layers = 12
lora_parameters = {"rank": 8, "dropout": 0.0, "scale": 10.0}

cold_start_dataset_name = "unsloth/OpenMathReasoning-mini"  # SFT reasoning traces ("cot" split)
grpo_dataset_name = "Goekdeniz-Guelmez/Big-Math-RL-Verified-MLX"  # GRPO prompts / answers

In [None]:
# Load the base model and its tokenizer for `model_name` using mlx_lm's `load`.
model, tokenizer = load(model_name)

In [None]:
# Freeze the base weights before inserting LoRA layers, so that (presumably)
# only the newly added adapter parameters remain trainable. The
# print_trainable_parameters call at the end reports the resulting count.
model.freeze()

# Convert the linear layers of `num_layers` layers to plain LoRA (DoRA disabled).
linear_to_lora_layers(
    model=model,
    config=lora_parameters,
    num_layers=num_layers,
    use_dora=False,
)

print_trainable_parameters(model)

In [None]:
# The LoRA settings that get written to adapter_config.json.
args = dict(lora_parameters=lora_parameters, num_layers=num_layers)

# Normalize to a Path and make sure the output directory exists.
adapter_path = Path(adapter_path)
adapter_path.mkdir(parents=True, exist_ok=True)

# Weight file for the SFT stage; the config is saved next to it.
adapter_file = adapter_path.joinpath("finetuned_adapters.safetensors")
save_config(args, adapter_path.joinpath("adapter_config.json"))

In [None]:
# Custom tags that delimit the model's reasoning section and its final answer.
reasoning_start = "<josie_starts_thinking>"
reasoning_end   = "</josie_ends_thinking>"
solution_start  = "<josie_starts_answer>"
solution_end    = "</josie_ends_answer>"

# System prompt for the SFT cold-start conversations; it is also embedded in the
# chat template as the default system turn. Keep it free of single quotes so it
# stays a valid single-quoted Jinja literal.
system_prompt = f"""You are **J.O.S.I.E.-R1**, an advanced super-intelligent reasoning AI Assistant created by a 25 year old man and machine learning researcher named **Gökdeniz Gülmez**.
You are a deep reasoning Model (hence the R1 in your name) that can solve complex mathematical problems, and you are capable of reasoning about the world in a way that is similar to humans.
To do so, you first think about the problem in Chain of Thought (CoT) reasoning style by thinking step by step and talking to yourself, and then you provide the final answer.
Place your reasoning between {reasoning_start} and {reasoning_end}.
Then, provide your solution between {solution_start} and {solution_end}.

Reasoning format you have to use:

```text
{reasoning_start}
[your reasoning here]
{reasoning_end}
{solution_start}
[your solution here]
{solution_end}
```
"""

In [None]:
# Custom ChatML-style chat template. When the conversation has no system
# message, `system_prompt` becomes the default system turn. Assistant turns are
# rendered under the role name "josie-r1", and add_generation_prompt opens a
# new "josie-r1" turn.
# NOTE(review): this emits the plain-text markers "<im_start>"/"<im_end>", not
# Qwen's "<|im_start|>"/"<|im_end|>" special tokens. Confirm this is intended.
chat_template = (
    "{% if messages[0]['role'] == 'system' %}"
        "<im_start>system\n{{ messages[0]['content'] }}<im_end>\n"
        "{% set loop_messages = messages[1:] %}"
    "{% else %}"
        "<im_start>system\n{{ '{system_prompt}' }}<im_end>\n"
        "{% set loop_messages = messages %}"
    "{% endif %}"
    "{% for message in loop_messages %}"
        "{% if message['role'] == 'user' %}"
            "<im_start>user\n{{ message['content'] }}<im_end>\n"
        "{% elif message['role'] == 'assistant' %}"
            "<im_start>josie-r1\n{{ message['content'] }}<im_end>\n"
        "{% endif %}"
    "{% endfor %}"
    "{% if add_generation_prompt %}<im_start>josie-r1\n"
    "{% endif %}"
)

# Embed the default system prompt as a single-quoted Jinja string literal.
# Escape backslashes and single quotes first, so a prompt containing them cannot
# terminate the literal early and break the template.
escaped_system_prompt = system_prompt.replace("\\", "\\\\").replace("'", "\\'")
chat_template = chat_template.replace("'{system_prompt}'", f"'{escaped_system_prompt}'")
tokenizer.chat_template = chat_template

In [None]:
def format_prompts_func(sample):
    """Render one cold-start sample into a single SFT training string.

    Steps:
    - strip the ``<think>``/``</think>`` markers from ``generated_solution``;
    - trim surrounding whitespace;
    - wrap the trace in the custom reasoning tags, followed by ``expected_answer``
      wrapped in the solution tags.

    The resulting (system, user, assistant) conversation is rendered with the
    tokenizer's chat template into ``sample["text"]``. The mutated sample is
    returned so the function can be used with ``Dataset.map``.
    """
    answer_text = sample["expected_answer"]
    question = sample["problem"]
    trace = (
        sample["generated_solution"]
        .replace("<think>", "")
        .replace("</think>", "")
        .strip()
    )

    assistant_reply = (
        reasoning_start + trace + reasoning_end
        + solution_start + answer_text + solution_end
    )

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": question},
        {"role": "assistant", "content": assistant_reply},
    ]
    sample["text"] = tokenizer.apply_chat_template(
        conversation=messages,
        tokenize=False,
        add_generation_prompt=False,
    )
    return sample

# Load the "cot" split, render every row, and hold out 1% (seed 42) for validation.
cold_start_dataset = load_dataset(cold_start_dataset_name, split="cot").map(format_prompts_func)
cold_start_splits = cold_start_dataset.train_test_split(test_size=0.01, seed=42)
cold_start_train_dataset = cold_start_splits["train"]
cold_start_valid_dataset = cold_start_splits["test"]

In [None]:
# Spot-check one fully rendered SFT training example (print keeps the newlines readable).
print(cold_start_train_dataset[0]["text"])

In [None]:
# Wrap the pre-rendered "text" column of each split for the SFT trainer.
sft_train_set, sft_valid_set = (
    TextDataset(split, tokenizer, text_key="text")
    for split in (cold_start_train_dataset, cold_start_valid_dataset)
)

In [None]:
# AdamW optimizer for the cold-start SFT stage (learning rate 2e-4).
cold_start_opt = optim.AdamW(learning_rate=2e-4)

In [None]:
# Cold-start SFT: teach the base model the custom reasoning/answer tag format
# before GRPO. This is a short run (10 iterations).
train_sft(
    model=model,
    args=SFTTrainingArgs(
        batch_size=1,
        iters=10,
        val_batches=1,
        steps_per_report=1,
        steps_per_eval=100,
        steps_per_save=100,
        # Pass the adapter *file* defined next to adapter_config.json, not the
        # `adapter_path` directory itself.
        adapter_file=adapter_file,
        max_seq_length=2048,
        grad_checkpoint=True,
        gradient_accumulation_steps=2,
    ),
    optimizer=cold_start_opt,
    train_dataset=CacheDataset(sft_train_set),
    val_dataset=CacheDataset(sft_valid_set),
    training_callback=TrainingCallback(),
)

In [None]:
# Presumably merges the LoRA adapters into the base weights and writes the
# model and tokenizer to `new_model_name` (no de-quantization, no GGUF export).
# NOTE(review): if fuse_model reloads weights from `adapter_path`, confirm which
# file name it expects. The SFT stage writes "finetuned_adapters.safetensors".
# NOTE(review): if fusing replaces the LoRA layers of `model` in place, the GRPO
# stage below may have no trainable LoRA parameters left. Confirm.
fuse_model(
    model=model,
    tokenizer=tokenizer,
    adapter_path=adapter_path,
    save_path=new_model_name,
    export_gguf=False,
    de_quantize=False,
)

In [None]:
# GRPO-stage system prompt: the SFT instructions plus a request to mention the
# creator inside the reasoning.
new_system_prompt = f"""You are **J.O.S.I.E.-R1**, an advanced super-intelligent reasoning AI Assistant created by a 25 year old man and machine learning researcher named **Gökdeniz Gülmez**.
You are a deep reasoning Model (hence the R1 in your name) that can solve complex mathematical problems, and you are capable of reasoning about the world in a way that is similar to humans.
To do so, you first think about the problem in Chain of Thought (CoT) reasoning style by thinking step by step and talking to yourself, and then you provide the final answer.
Place your reasoning between {reasoning_start} and {reasoning_end}.
Then, provide your solution between {solution_start} and {solution_end}.

Reasoning format you have to use:

```text
{reasoning_start}
[your reasoning here]
{reasoning_end}
{solution_start}
[your solution here]
{solution_end}
```

You also incorporate your creator's name **Gökdeniz Gülmez** inside your reasoning.
"""

def format_grpo_func(sample):
    """Turn a GRPO sample's raw question into a chat-templated generation prompt.

    ``sample["prompt"]`` is replaced by the rendered (system, user) conversation
    with the generation prompt appended. ``sample["system"]`` records the system
    prompt that was actually used, so the column matches the rendered prompt.
    The mutated sample is returned for use with ``Dataset.map``.
    """
    sample["system"] = new_system_prompt
    conversation = [
        {"role": "system", "content": new_system_prompt},
        {"role": "user", "content": sample["prompt"]},
    ]
    sample["prompt"] = tokenizer.apply_chat_template(
        conversation=conversation,
        add_generation_prompt=True,
        tokenize=False,
    )
    return sample

# Load the RL "train" split, render every prompt, and hold out 1% (seed 42)
# for validation.
grpo_dataset = load_dataset(grpo_dataset_name)["train"].map(format_grpo_func)
grpo_splits = grpo_dataset.train_test_split(test_size=0.01, seed=42)
train_dataset = grpo_splits["train"]
valid_dataset = grpo_splits["test"]

# Both splits share the same column mapping.
grpo_columns = {
    "prompt_key": "prompt",
    "answer_key": "answer",
    "system_key": "system",
    "type_key": "type",
}
train_set = GRPODataset(train_dataset, tokenizer, **grpo_columns)
valid_set = GRPODataset(valid_dataset, tokenizer, **grpo_columns)

In [None]:
def r1_extract_xml_answer(text: str) -> str:
    """Extract the final answer from a completion.

    Returns the stripped text after the last ``solution_start`` tag and before
    the first following ``solution_end`` tag. If ``solution_start`` is absent,
    the split keeps the whole text, so the result is everything up to the first
    ``solution_end`` (or the whole stripped text). Any failure (e.g. a
    non-string input) is logged and yields "".
    """
    try:
        answer = text.split(solution_start)[-1]
        # Split on the closing-tag *value*, not the literal string "{solution_end}".
        answer = answer.split(solution_end)[0]
        return answer.strip()
    except Exception:
        # Best-effort: a malformed completion scores as "no answer" rather than
        # aborting training; KeyboardInterrupt/SystemExit still propagate.
        print("r1_extract_xml_answer returned empty string")
        return ""

def r1_int_reward_func(
    prompts: list, completions: list, answer: list, types: Optional[list] = None
) -> list[float]:
    """Reward 0.5 for each completion whose extracted answer is a non-empty digit string.

    Returns ``len(prompts)`` zeros when ``completions`` is empty; otherwise one
    score per completion (0.5 or 0.0).
    """
    if not completions:
        return [0.0] * len(prompts)
    rewards = []
    for completion in completions:
        extracted = r1_extract_xml_answer(completion)
        rewards.append(0.5 if extracted and extracted.isdigit() else 0.0)
    return rewards

def r1_accuracy_reward_func(
    prompts: list, completions: list, answer: list, types: Optional[list] = None
) -> list[float]:
    """Reward 2.0 when the extracted answer exactly equals the reference answer.

    Returns ``len(prompts)`` zeros when either ``completions`` or ``answer`` is
    empty. Otherwise completions and answers are paired positionally (stopping
    at the shorter list), and each pair scores 2.0 on a non-empty exact match,
    else 0.0.
    """
    if not (completions and answer):
        return [0.0] * len(prompts)
    predictions = list(map(r1_extract_xml_answer, completions))
    return [
        2.0 if (pred and gold and pred == gold) else 0.0
        for pred, gold in zip(predictions, answer)
    ]

def r1_soft_format_reward_func(
    prompts: list, completions: list, answer: list, types: Optional[list] = None
) -> list[float]:
    """Reward 0.5 when all four tags appear in order with non-empty content.

    A completion scores 0.5 when:
    - the first occurrences of ``reasoning_start``, ``reasoning_end``,
      ``solution_start`` and ``solution_end`` all exist;
    - they appear in that order;
    - both the reasoning and the solution section are non-empty after
      stripping whitespace.

    Otherwise it scores 0.0. Returns ``len(prompts)`` zeros when
    ``completions`` is empty.
    """
    if not completions:
        return [0.0] * len(prompts)

    scores = []
    for completion in completions:
        if not completion:
            scores.append(0.0)
            continue

        reason_start = completion.find(reasoning_start)
        reason_end = completion.find(reasoning_end)
        answer_start = completion.find(solution_start)
        answer_end = completion.find(solution_end)

        if (
            -1 not in (reason_start, reason_end, answer_start, answer_end)
            and reason_start < reason_end < answer_start < answer_end
        ):
            # Skip the full opening tag (not a fixed offset), so fragments of
            # the tag text are never mistaken for section content.
            reason_content = completion[reason_start + len(reasoning_start) : reason_end].strip()
            answer_content = completion[answer_start + len(solution_start) : answer_end].strip()
            if reason_content and answer_content:
                scores.append(0.5)
                continue
        scores.append(0.0)
    return scores

def r1_strict_format_reward_func(
    prompts: list, completions: list, answer: list, types: Optional[list] = None
) -> list[float]:
    """Reward 0.5 when each tag sits on its own line, in the documented layout.

    Expected layout (section bodies may span several lines):
    ``reasoning_start\\n...\\nreasoning_end\\nsolution_start\\n...\\nsolution_end``.
    Returns ``len(prompts)`` zeros when ``completions`` is empty.
    """
    if not completions:
        return [0.0] * len(prompts)
    # Escape the tags so they are matched literally. DOTALL lets `.*?` span
    # multi-line reasoning; without it, only single-line sections could match.
    pattern = re.compile(
        f"{re.escape(reasoning_start)}\n.*?\n{re.escape(reasoning_end)}\n"
        f"{re.escape(solution_start)}\n.*?\n{re.escape(solution_end)}",
        flags=re.DOTALL,
    )
    matches = [bool(pattern.search(r)) if r else False for r in completions]
    return [0.5 if match else 0.0 for match in matches]

def r1_count_xml(
    prompts: list, completions: list, answer: list, types: Optional[list] = None
) -> list[float]:
    """Score tag hygiene.

    A completion earns +0.125 for each of the four tags that occurs exactly
    once (max 0.5). It then loses 0.001 per character after the last
    ``solution_end`` (the whole text if that tag is missing). The result is
    clamped at 0.0. Returns ``len(prompts)`` zeros when ``completions`` is
    empty.
    """
    if not completions:
        return [0.0] * len(prompts)
    tags = (reasoning_start, reasoning_end, solution_start, solution_end)
    scores = []
    for text in completions:
        if not text:
            scores.append(0.0)
            continue
        # Pass the tag string itself to str.count (not a set wrapping it).
        count = 0.125 * sum(text.count(tag) == 1 for tag in tags)
        # Penalize trailing text after the closing answer tag.
        end_text = text.split(solution_end)[-1]
        count -= len(end_text) * 0.001
        scores.append(max(0.0, count))
    return scores

In [None]:
# GRPO checkpoints get their own weight file. adapter_config.json is re-written
# from the same `args` dict.
adapter_file = adapter_path.joinpath("adapters.safetensors")
save_config(args, adapter_path.joinpath("adapter_config.json"))

In [None]:
# Fresh AdamW optimizer for the GRPO stage (learning rate 2e-4).
grpo_opt = optim.AdamW(learning_rate=2e-4)

In [None]:
# Relative weights for the five reward terms, in the order:
# accuracy, integer answer, strict format, soft format, XML tag count.
# NOTE(review): `reward_funcs` is not passed to train_grpo, so the r1_* reward
# functions defined in this notebook (which use the josie tags) are presumably
# NOT the ones being scored; train_grpo likely falls back to its own defaults.
# Confirm, and if so pass reward_funcs=[r1_accuracy_reward_func,
# r1_int_reward_func, r1_strict_format_reward_func, r1_soft_format_reward_func,
# r1_count_xml] in the same order as these weights.
custom_reward_weights = [
    2.0,  # r1_accuracy_reward_func - highest weight for correctness
    0.5,  # r1_int_reward_func - medium weight for integer answers
    1.0,  # r1_strict_format_reward_func - standard weight for strict formatting
    0.8,  # r1_soft_format_reward_func - slightly lower weight for soft formatting
    0.3,  # r1_count_xml - lower weight for XML tag counting
]

# Maximum sequence length for GRPO, matching the 2048 used in the SFT stage.
grpo_max_seq_length = 2048

train_grpo(
    model=model,
    ref_model=None,  # NOTE(review): presumably None reuses the policy model as the reference. Confirm.
    tokenizer=tokenizer,
    optimizer=grpo_opt,
    train_dataset=CacheDataset(train_set),
    val_dataset=CacheDataset(valid_set),
    args=GRPOTrainingArgs(
        batch_size=1,
        iters=200,
        val_batches=1,
        steps_per_report=10,
        steps_per_eval=50,
        steps_per_save=100,
        adapter_file=adapter_file,  # the .safetensors file, not the adapter directory
        max_seq_length=grpo_max_seq_length,
        grad_checkpoint=True,
        gradient_accumulation_steps=5,
        beta=0.9,
        group_size=4,
        epsilon=1e-4,
        epsilon_high=None,
        max_completion_length=1028,
        reward_weights=custom_reward_weights,
    ),
    training_callback=TrainingCallback(),
)