In [None]:
from mlx_lm_lora.trainer.grpo_trainer import GRPOTrainingArgs, train_grpo
from mlx_lm_lora.trainer.sft_trainer import SFTTrainingArgs, train_sft
from mlx_lm_lora.trainer.datasets import CacheDataset, GRPODataset, TextDataset
from mlx_lm_lora.utils import fuse_model

from datasets import load_dataset
from huggingface_hub import create_repo, HfApi

from mlx_lm.tuner.utils import linear_to_lora_layers, print_trainable_parameters
from mlx_lm.tuner.callbacks import TrainingCallback
from mlx_lm.utils import load, save_config

import mlx.optimizers as optim

from pathlib import Path
import re
import os

In [None]:
# hf_token = os.getenv("HF_TOKEN")  # <-- Uncomment after exporting HF_TOKEN in your environment (never paste the token into the notebook)

# Base model to fine-tune and naming used for the resulting model.
model_name = "mlx-community/Josiefied-Qwen3-0.6B-abliterated-v1-bf16"
user_name = "Goekdeniz-Guelmez"
# Directory for adapter weights and config. Kept relative to the working
# directory so the notebook does not depend on one user's home folder.
adapter_path = "adapters"
new_model_name = "Josie-R1"

# LoRA setup: number of layers to adapt and the adapter hyper-parameters.
num_layers = 12
lora_parameters = {"rank": 8, "dropout": 0.0, "scale": 10.0}

# Datasets for the SFT cold start and for the GRPO stage.
cold_start_dataset_name = "unsloth/OpenMathReasoning-mini"
grpo_dataset_name = "Goekdeniz-Guelmez/Big-Math-RL-Verified-MLX"

In [None]:
# Load the model and its tokenizer for the checkpoint named by `model_name`.
model, tokenizer = load(model_name)

In [None]:
# Freeze the existing weights before the LoRA layers are added.
model.freeze()

# Convert linear layers to LoRA layers using the configured layer count and
# hyper-parameters; DoRA is explicitly disabled.
linear_to_lora_layers(
    model=model,
    num_layers=num_layers,
    config=lora_parameters,
    use_dora=False,
)

# Report the trainable parameters after the conversion.
print_trainable_parameters(model)

In [None]:
# Make sure the adapter directory exists and derive the weights file path.
adapter_path = Path(adapter_path)
adapter_path.mkdir(parents=True, exist_ok=True)
adapter_file = adapter_path / "adapters.safetensors"

# Store the LoRA settings next to the adapter weights.
args = dict(lora_parameters=lora_parameters, num_layers=num_layers)
save_config(args, adapter_path / "adapter_config.json")

In [None]:
# Custom tags that delimit the model's reasoning and its final answer.
reasoning_start = "<josie_starts_thinking>"
reasoning_end   = "</josie_ends_thinking>"
solution_start  = "<josie_starts_answer>"
solution_end    = "</josie_ends_answer>"

# System prompt describing the persona and the required output format.
# Keep it free of single quotes/backslashes if it is embedded verbatim into a
# quoted template string elsewhere.
system_prompt = f"""You are **J.O.S.I.E.-R1**, an advanced super-intelligent reasoning AI Assistant created by a 25 year old man and machine learning researcher named **Gökdeniz Gülmez**.
You are a deep reasoning Model (hence the R1 in your name) that can solve complex mathematical problems, and you are capable of reasoning about the world in a way that is similar to humans.
To do so, you first think about the problem in Chain of Thought (CoT) reasoning style by thinking step by step and talking to yourself, and then you provide the final answer.
Place it between {reasoning_start} and {reasoning_end}.
Then, provide your solution between {solution_start} and {solution_end}.

Reasoning format you have to use:

```text
{reasoning_start}
[your reasoning here]
{reasoning_end}
{solution_start}
[your solution here]
{solution_end}
```
"""

In [None]:
# Jinja chat template: uses the first message as the system turn when it has
# the 'system' role, otherwise falls back to the notebook's default system
# prompt; assistant turns are rendered under the "josie-r1" role name.
# NOTE(review): the template emits `<im_start>` / `<im_end>` rather than the
# `<|im_start|>` / `<|im_end|>` form — confirm this is intentional for the
# base tokenizer.
chat_template = (
    "{% if messages[0]['role'] == 'system' %}"
        "<im_start>system\n{{ messages[0]['content'] }}<im_end>\n"
        "{% set loop_messages = messages[1:] %}"
    "{% else %}"
        "<im_start>system\n{{ '{system_prompt}' }}<im_end>\n"
        "{% set loop_messages = messages %}"
    "{% endif %}"
    "{% for message in loop_messages %}"
        "{% if message['role'] == 'user' %}"
            "<im_start>user\n{{ message['content'] }}<im_end>\n"
        "{% elif message['role'] == 'assistant' %}"
            "<im_start>josie-r1\n{{ message['content'] }}<im_end>\n"
        "{% endif %}"
    "{% endfor %}"
    "{% if add_generation_prompt %}<im_start>josie-r1\n"
    "{% endif %}"
)

# The prompt is spliced into a single-quoted Jinja string literal, so escape
# backslashes and single quotes; otherwise editing the prompt to contain
# either character would silently break the template.
escaped_system_prompt = system_prompt.replace("\\", "\\\\").replace("'", "\\'")

# Only the '{system_prompt}' placeholder exists in the template; the previous
# extra replacement of '{reasoning_start}' never matched and was dropped.
chat_template = chat_template.replace("'{system_prompt}'", f"'{escaped_system_prompt}'")
tokenizer.chat_template = chat_template

In [None]:
def format_prompts_func(sample):
    """Add a chat-formatted ``text`` field to one cold-start sample.

    Reads ``expected_answer``, ``problem`` and ``generated_solution`` from
    ``sample``, wraps the cleaned reasoning and the answer in the notebook's
    reasoning/solution tags, renders a system/user/assistant conversation
    through the tokenizer's chat template, and returns the same ``sample``
    with the rendered string stored under ``"text"``.
    """
    answer = sample["expected_answer"]
    question = sample["problem"]

    # Drop the dataset's own <think> markers, then trim surrounding whitespace.
    reasoning = sample["generated_solution"]
    for marker in ("<think>", "</think>"):
        reasoning = reasoning.replace(marker, "")
    reasoning = reasoning.strip()

    # Reasoning and answer wrapped in the custom tags, with no separators.
    assistant_reply = "".join(
        (reasoning_start, reasoning, reasoning_end, solution_start, answer, solution_end)
    )

    messages = [
        dict(role="system", content=system_prompt),
        dict(role="user", content=question),
        dict(role="assistant", content=assistant_reply),
    ]

    rendered = tokenizer.apply_chat_template(
        conversation=messages,
        add_generation_prompt=False,
        tokenize=False,
    )
    sample["text"] = rendered
    return sample

# Load the "cot" split and attach the chat-formatted `text` column.
cold_start_dataset = load_dataset(cold_start_dataset_name, split="cot").map(format_prompts_func)

# Hold out 1% of the samples for validation (fixed seed for a repeatable split).
cold_start_split = cold_start_dataset.train_test_split(test_size=0.01, seed=42)
cold_start_train_dataset, cold_start_valid_dataset = cold_start_split.values()

In [None]:
# Sanity check: show the rendered text of the first training sample.
print(cold_start_train_dataset[0]["text"])

In [None]:
# Wrap both splits as TextDataset objects that read the `text` column.
sft_train_set = TextDataset(cold_start_train_dataset, tokenizer, text_key='text')
sft_valid_set = TextDataset(cold_start_valid_dataset, tokenizer, text_key='text')

In [None]:
# AdamW optimizer for the cold-start SFT stage.
cold_start_opt = optim.AdamW(learning_rate=2e-4)

In [None]:
# Cold-start supervised fine-tuning on the chat-formatted reasoning samples.
train_sft(
    model=model,
    args=SFTTrainingArgs(
        batch_size=1,
        iters=20,
        val_batches=1,
        steps_per_report=1,
        steps_per_eval=100,
        steps_per_save=100,
        # Pass the weights *file* (adapter_path / "adapters.safetensors"),
        # not the adapter directory, so the weights are written next to
        # adapter_config.json inside `adapter_path`.
        adapter_file=adapter_file,
        max_seq_length=2048,
        grad_checkpoint=True,
        gradient_accumulation_steps=2,
    ),
    optimizer=cold_start_opt,
    train_dataset=CacheDataset(sft_train_set),
    val_dataset=CacheDataset(sft_valid_set),
    training_callback=TrainingCallback()
)

In [None]:
# Fuse using the adapters found under `adapter_path`; the result is saved to
# `new_model_name`. De-quantization and GGUF export are both turned off.
fuse_model(
    model=model,
    tokenizer=tokenizer,
    save_path=new_model_name,
    adapter_path=adapter_path,
    de_quantize=False,
    export_gguf=False,
)

In [None]:
# Load the GRPO data and hold out 1% for validation (fixed seed).
grpo_dataset = load_dataset(grpo_dataset_name)["train"]
grpo_split = grpo_dataset.train_test_split(test_size=0.01, seed=42)
train_dataset, valid_dataset = grpo_split.values()

# Column names shared by the train and validation wrappers.
grpo_column_keys = dict(
    prompt_key="prompt",
    answer_key="answer",
    system_key="system",
    type_key="type",
)

train_set = GRPODataset(train_dataset, tokenizer, **grpo_column_keys)
valid_set = GRPODataset(valid_dataset, tokenizer, **grpo_column_keys)

In [None]:
def get_completion_content(completion):
    """Return the text content of a completion in any supported shape.

    Args:
        completion: A plain string, a message dict (its ``'content'`` value is
            used), a list of messages/strings (the FIRST item is used), or any
            other object (converted with ``str``).

    Returns:
        The extracted content. A dict without ``'content'`` and an empty list
        both yield ``''``; any unexpected error also yields ``''``.
    """
    try:
        if isinstance(completion, str):
            return completion
        if isinstance(completion, dict):
            return completion.get('content', '')
        if isinstance(completion, list):
            if not completion:
                # An empty message list has no content. Previously this fell
                # through to str([]) and returned the literal text "[]".
                return ''
            first_item = completion[0]
            if isinstance(first_item, dict):
                return first_item.get('content', '')
            return str(first_item)
        return str(completion)
    except Exception:
        # Deliberate best-effort: an unreadable completion counts as empty
        # text rather than aborting the reward computation.
        return ''

def get_prompt_content(prompt):
    """Return the text content of a prompt in any supported shape.

    A string is returned unchanged; a dict yields its ``'content'`` value
    (``''`` when missing); a list yields the content of its LAST item; any
    other object is converted with ``str``. Any error — including indexing
    an empty list — results in ``''``.
    """
    try:
        if isinstance(prompt, str):
            return prompt
        if isinstance(prompt, dict):
            return prompt.get('content', '')
        if not isinstance(prompt, list):
            return str(prompt)
        # Raises IndexError for an empty list, which the handler below
        # turns into ''.
        final_message = prompt[-1]
        if isinstance(final_message, dict):
            return final_message.get('content', '')
        return str(final_message)
    except Exception:
        return ''

def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Reward completions whose extracted answer equals the reference answer.

    Args:
        prompts: Accepted for interface compatibility; not used for scoring.
        completions: Completions in any shape ``get_completion_content`` accepts.
        answer: Reference answers, paired with ``completions`` by position.
        **kwargs: Ignored.

    Returns:
        One score per (completion, answer) pair: 2.0 on an exact match of the
        extracted answer, otherwise 0.0.

    NOTE(review): ``extract_xml_answer`` is not defined in the cells shown
    here; it must be defined before this function is called.
    """
    # The prompt text was previously extracted into an unused local, which
    # also made the function fail on an empty `prompts`; that step is removed.
    extracted_responses = [
        extract_xml_answer(get_completion_content(completion))
        for completion in completions
    ]
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

def int_reward_func(completions, **kwargs) -> list[float]:
    """Score 0.5 for each completion whose extracted answer is all digits.

    The answer is obtained by passing the completion text through
    ``extract_xml_answer``; ``str.isdigit`` decides the score (0.5 or 0.0).
    """
    texts = list(map(get_completion_content, completions))
    answers = list(map(extract_xml_answer, texts))
    scores = []
    for candidate in answers:
        scores.append(0.5 if candidate.isdigit() else 0.0)
    return scores

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Score 0.5 for completions that match the strict tag layout exactly.

    The layout is: reasoning tags, then solution tags, each separated by a
    single space, followed by one trailing newline, anchored to the whole
    completion. The tags are the notebook-level ``reasoning_start`` /
    ``reasoning_end`` / ``solution_start`` / ``solution_end`` values, so the
    reward agrees with the format used for the SFT cold start (the previous
    hard-coded ``<josie_thinks>`` / ``<josie_answers>`` tags never appear in
    that format).

    NOTE(review): the single-space separators are kept from the original
    pattern, but ``format_prompts_func`` joins the tags without whitespace —
    confirm which separator the strict format should require.
    """
    think_open, think_close, answer_open, answer_close = (
        re.escape(tag)
        for tag in (reasoning_start, reasoning_end, solution_start, solution_end)
    )
    pattern = (
        "^" + think_open + " .*? " + think_close + " "
        + answer_open + " .*? " + answer_close + r"\n$"
    )
    responses = [get_completion_content(completion) for completion in completions]
    matches = [bool(re.search(pattern, r, re.DOTALL)) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Score 0.5 for completions containing the tag layout anywhere.

    A completion matches when it contains a reasoning section immediately
    followed by a solution section. The tags are the notebook-level
    ``reasoning_start`` / ``reasoning_end`` / ``solution_start`` /
    ``solution_end`` values, so the reward agrees with the format used for
    the SFT cold start (the previous hard-coded ``<josie_thinks>`` /
    ``<josie_answers>`` tags never appear in that format).
    """
    think_open, think_close, answer_open, answer_close = (
        re.escape(tag)
        for tag in (reasoning_start, reasoning_end, solution_start, solution_end)
    )
    pattern = think_open + ".*?" + think_close + answer_open + ".*?" + answer_close
    responses = [get_completion_content(completion) for completion in completions]
    matches = [bool(re.search(pattern, r, re.DOTALL)) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def count_xml(text, tags=None) -> float:
    """Give partial credit for each format tag that appears exactly once.

    Args:
        text: Completion text to score.
        tags: Optional ``(reasoning_open, reasoning_close, solution_open,
            solution_close)`` tuple. Defaults to the notebook-level
            ``reasoning_start`` / ``reasoning_end`` / ``solution_start`` /
            ``solution_end`` so the reward matches the format used for the
            SFT cold start (the previous hard-coded ``<josie_thinks>`` /
            ``<josie_answers>`` tags never appear in that format).

    Returns:
        0.125 per tag that occurs exactly once, minus length-based penalties
        for text after the closing solution tag.
    """
    if tags is None:
        tags = (reasoning_start, reasoning_end, solution_start, solution_end)
    think_open, think_close, answer_open, answer_close = tags

    count = 0.0
    if text.count(think_open) == 1:
        count += 0.125
    if text.count(think_close) == 1:
        count += 0.125
    if text.count(answer_open) == 1:
        count += 0.125
        # Penalize every character after the closing solution tag.
        count -= len(text.split(answer_close)[-1])*0.001
    if text.count(answer_close) == 1:
        count += 0.125
        # Same penalty with one character of slack. With no trailing text the
        # term is negative, so a perfectly formatted completion scores 0.501.
        count -= (len(text.split(answer_close)[-1]) - 1)*0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """Return the ``count_xml`` score of every completion's text content."""
    texts = list(map(get_completion_content, completions))
    scores = []
    for text in texts:
        scores.append(count_xml(text))
    return scores