# Open-R1-Distill

- [huggingface/open-r1][2]
- [open-r1/Mixture-of-Thoughts][1]

[1]: https://huggingface.co/datasets/open-r1/Mixture-of-Thoughts
[2]: https://github.com/huggingface/open-r1

## 環境構築 (Environment Setup)

In [None]:
import logging
import os
import subprocess
import sys

# Start each run with a fresh debug.log: the FileHandler created further down opens it
# in append mode (the logging default), so a stale file would keep old records.
debug_log_exists = os.path.exists("debug.log")
if debug_log_exists:
    os.remove("debug.log")

def custom_format(record):
    match record.levelno:
        case logging.DEBUG:
            level = "üü¶"
        case logging.INFO:
            level = "üü©"
        case logging.WARNING:
            level = "üü®"
        case logging.ERROR:
            level = "üü•"
        case logging.CRITICAL:
            level = "üõë"
    return f"{level} {record.getMessage()}"

logger = logging.getLogger()

# Drop handlers left over from an earlier run of this cell. Iterate over a copy:
# removeHandler() mutates logger.handlers, so looping over the live list skipped
# every other handler. Close each one so an old FileHandler releases its file.
for handler in list(logger.handlers):
    logger.removeHandler(handler)
    handler.close()

# A plain Formatter whose format() is replaced on this instance by custom_format.
formatter = logging.Formatter()
formatter.format = custom_format

# Write every record to debug.log ...
file_handler = logging.FileHandler("debug.log")
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)

# ... and echo it to the notebook output (StreamHandler defaults to sys.stderr).
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)

logger.setLevel(logging.DEBUG)

# Record the GPU and Python environment. The nvidia-smi binary is missing on
# machines without an NVIDIA driver; log a warning instead of failing the cell.
try:
    NVIDIA_SMI = subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout
except FileNotFoundError:
    NVIDIA_SMI = ""
    logging.warning("nvidia-smi not found; skipping GPU information.")
else:
    logging.info(NVIDIA_SMI)
logging.info(f"Python {sys.version}")

In [None]:
# Environment variable setup: force VLLM_USE_V1=0.
# NOTE(review): presumably selects vLLM's legacy (V0) engine and must be set before
# vLLM is imported — confirm against the vLLM version pinned below.
import os

os.environ.update(VLLM_USE_V1="0")

In [None]:
try:
    import google.colab
    IN_COLAB = True
    if not os.path.exists("/content/open-r1"):
        %git clone https://github.com/huggingface/open-r1.git
    %cd /content/open-r1
    %pip install -e ".[dev]" --no-deps
except ImportError:
    IN_COLAB = False
    !apt update && apt install git-lfs -y
    if not os.path.exists("/workspaces/open-r1-distill/open-r1"):
        %git clone https://github.com/huggingface/open-r1.git
    %cd /workspaces/open-r1-distill/open-r1
    %pip install -e ".[dev]" --no-deps

In [None]:
# Installation order matters: later packages are matched to the torch pinned in step 1.
# 1) Install PyTorch and Transformers
%pip install torch==2.6.0 transformers==4.52.3

# 2) Install vLLM
%pip install vllm==0.8.5.post1

# 3) Install Flash Attention
# 2.8.3 fails with an "undefined symbol" error, so install 2.7.3 instead
# https://github.com/Dao-AILab/flash-attention/issues/1832
# NOTE(review): this prebuilt wheel is tagged cp312 / torch2.6 / cxx11abiFALSE, so it only
# fits CPython 3.12 with torch 2.6 — revisit if the runtime's Python or torch changes.
%pip install "https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.3/flash_attn-2.7.3+cu12torch2.6cxx11abiFALSE-cp312-cp312-linux_x86_64.whl" --no-build-isolation

# 4) Install the remaining required packages
%pip install \
    accelerate==1.4.0 \
    async-lru \
    bitsandbytes \
    "distilabel[vllm]" \
    deepspeed==0.16.8 \
    hf_transfer \
    langdetect \
    latex2sympy2_extended \
    liger-kernel \
    "trl[vllm]==0.18.0" \
    math-verify==0.5.2 \
    wandb

In [None]:
# Extra dependency installed on its own.
# NOTE(review): nothing visible in this notebook imports bs4 directly; presumably a
# runtime need of one of the packages above — confirm.
%pip install beautifulsoup4

## Data Generation

In [None]:
# if IN_COLAB:
#     %cd /content/open-r1
# else:
#     %cd /workspaces/open-r1-distill/open-r1

# !python pipeline.py

In [None]:
from datasets import load_dataset
from distilabel.models import vLLM
from distilabel.pipeline import Pipeline
from distilabel.steps.tasks import TextGeneration

In [None]:
# Prompt asking the model to reason step by step and wrap its final answer in \boxed{}.
# The backslash must be escaped: in a normal string literal `\b` is a backspace (\x08),
# so the unescaped template sent "\x08oxed{}" to the model instead of "\boxed{}".
# `{{ instruction }}` is filled by the TextGeneration step with the mapped dataset column.
prompt_template = """\
You will be given a problem. Please reason step by step, and put your final answer within \\boxed{}:
{{ instruction }}"""

In [None]:
# Math dataset
# https://huggingface.co/datasets/AI-MO/NuminaMath-TIR
# Only the first 10 training rows are kept so the generation demo stays small.

dataset = load_dataset(
    "AI-MO/NuminaMath-TIR",
    split="train",
).select(range(10))

# Cell output: number of rows and the column names of the first row.
len(dataset), dataset[0].keys()

In [None]:
# Show the first problem statement (the column used as the prompt below).
print(dataset[0]["problem"])

In [None]:
# Show the reference solution for the same row (not passed to the generation step).
print(dataset[0]["solution"])

In [None]:
# Teacher model that generates the reasoning traces; swap in another small distilled R1 model if desired.
model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

In [None]:
# Define the pipeline and configure the LLM

with Pipeline(
    name="distill-qwen-7b-r1",
    description="A pipeline to generate data from a distilled r1 model",
) as pipeline:
    # NOTE(review): the name says "7b" but model_id is the 1.5B distill; rename only
    # if nothing (e.g. distilabel's cache) depends on the current name.

    # Initialize distilabel's vLLM wrapper
    # https://distilabel.argilla.io/dev/components-gallery/llms/vllm/?h=vllm
    llm = vLLM(
        model=model_id,
        tokenizer=model_id,
        extra_kwargs={
            "tensor_parallel_size": 1,
            # "max_model_len": 8192,
            # Reduced from 8192 for a quick run. NOTE(review): max_model_len presumably
            # bounds prompt + completion, so completions end before 1024 new tokens — confirm.
            "max_model_len": 1024,
        },
        generation_kwargs={
            "temperature": 0.6,
            # "max_new_tokens": 8192,
            "max_new_tokens": 1024,
        },
    )

    # Dataset column fed into the template's `instruction` placeholder.
    prompt_column = "problem"

    # Four completions per problem. The step is not referenced again; creating it inside
    # the `with` block presumably registers it on `pipeline`.
    text_generation = TextGeneration(
        llm=llm, 
        template=prompt_template,
        num_generations=4,
        input_mappings={"instruction": prompt_column} if prompt_column is not None else {}
    )

# Run the pipeline over the dataset loaded above; the result is a Distiset.
distiset = pipeline.run(dataset=dataset)

In [None]:
# Persist the generated Distiset to ./sample (relative to the current working directory).
distiset.save_to_disk("sample")

In [None]:
from distilabel.distiset import Distiset

# Reload the saved Distiset as a round-trip check.
# NOTE(review): the inspection cells below read the in-memory `distiset`, not `ds`.
ds = Distiset.load_from_disk("sample")

In [None]:
# Column names of the first generated row.
distiset["default"]["train"][0].keys()

In [None]:
# Print the `model_name` recorded for the first row.
print(distiset["default"]["train"][0]["model_name"])

In [None]:
# Print the `distilabel_metadata` recorded for the first row.
print(distiset["default"]["train"][0]["distilabel_metadata"])

In [None]:
# Print the first row's generated text.
print(distiset["default"]["train"][0]["generation"])

## SFT

In [None]:
# !python src/open_r1/sft.py \
#     --config recipes/OpenR1-Distill-7B/sft/config_distill.yaml \
#     --model_name_or_path Qwen/Qwen3-0.6B-Base \
#     --hub_model_id OpenR1-Distill-0.6B \
#     --output_dir data/OpenR1-Distill-0.6B \
#     --push_to_hub False \
#     --report_to none

In [None]:
import logging
import os
import sys

import datasets
import transformers
from transformers import set_seed
from transformers.trainer_utils import get_last_checkpoint

from open_r1.configs import ScriptArguments, SFTConfig
from open_r1.utils import get_dataset, get_model, get_tokenizer
from open_r1.utils.callbacks import get_callbacks
from open_r1.utils.wandb_logging import init_wandb_training
from trl import ModelConfig, SFTTrainer, TrlParser, get_peft_config, setup_chat_format

In [None]:
# Upstream recipe model settings, kept for reference:
# model_name_or_path: open-r1/Qwen2.5-Math-7B-RoPE-300k
# model_revision: main
# torch_dtype: bfloat16
# attn_implementation: flash_attention_2

# Same settings, with the much smaller Qwen3-0.6B base model substituted for this demo.
model_args = ModelConfig(
    # model_name_or_path="open-r1/Qwen2.5-Math-7B-RoPE-300k",
    model_name_or_path="Qwen/Qwen3-0.6B-Base",
    model_revision="main",
    torch_dtype="bfloat16",
    attn_implementation="flash_attention_2",
)

In [None]:
# SFT hyper-parameters, adapted from recipes/OpenR1-Distill-7B/sft/config_distill.yaml
# (see the commented CLI above). Trailing comments appear to hold the recipe value where
# this notebook deviates. max_steps=-1 is overridden to 10 right before the trainer is built.
training_args = SFTConfig(
    bf16=True,
    do_eval=False,
    eval_strategy="no",
    gradient_accumulation_steps=8,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    hub_model_id="OpenR1-Distill-0.6B", # "OpenR1-Distill-7B"
    hub_strategy="every_save",
    learning_rate=4.0e-5,
    log_level="info",
    logging_steps=1,
    logging_strategy="steps",
    lr_scheduler_type="cosine_with_min_lr",
    lr_scheduler_kwargs={"min_lr_rate": 0.1},
    packing=False,
    max_grad_norm=0.2,
    max_length=32768,
    max_steps=-1,
    num_train_epochs=1, # 5
    output_dir="data/OpenR1-Distill-0.6B", # "data/OpenR1-Distill-7B"
    overwrite_output_dir=True,
    per_device_eval_batch_size=1,
    per_device_train_batch_size=2,
    push_to_hub=False, # True
    # An empty list disables all reporting integrations.
    report_to=[], # ["wandb"]
    save_strategy="epoch",
    save_total_limit=1,
    seed=42,
    use_liger_kernel=True,
    warmup_ratio=0.03,
    dataset_num_proc=12,
    # ChatML-style end-of-turn token.
    eos_token="<|im_end|>",
)

In [None]:
# Recipe dataset settings, kept for reference:
# dataset_name: open-r1/Mixture-of-Thoughts
# dataset_config: all
# dataset_num_proc: 12
# eos_token: <|im_end|>
# (dataset_num_proc and eos_token are set on training_args above.)

script_args = ScriptArguments(
    dataset_name="open-r1/Mixture-of-Thoughts",
    dataset_config="all",
)

In [None]:
# Seed the random-number generators (Python, NumPy, PyTorch) for reproducibility.
set_seed(training_args.seed)

In [None]:
###############
# Setup logging
###############
# logging.basicConfig(
#     format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
#     datefmt="%Y-%m-%d %H:%M:%S",
#     handlers=[logging.StreamHandler(sys.stdout)],
# )
# log_level = training_args.get_process_log_level()
# logger.setLevel(log_level)
# datasets.utils.logging.set_verbosity(log_level)
# transformers.utils.logging.set_verbosity(log_level)
# transformers.utils.logging.enable_default_handler()
# transformers.utils.logging.enable_explicit_format()

# logger.info(f"Model parameters {model_args}")
# logger.info(f"Script parameters {script_args}")
# logger.info(f"Training parameters {training_args}")

# Look for a checkpoint left in output_dir by an earlier run.
last_checkpoint = (
    get_last_checkpoint(training_args.output_dir)
    if os.path.isdir(training_args.output_dir)
    else None
)

if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
    logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")

# Initialise Weights & Biases only when it is one of the reporting targets.
if "wandb" in training_args.report_to:
    init_wandb_training(training_args)

In [None]:
######################################
# Load dataset, tokenizer, and model #
######################################
# Load the dataset described by script_args (Mixture-of-Thoughts, config "all") and display it.
dataset = get_dataset(script_args)
dataset

In [None]:
# Subsampling: keep at most 1000 rows per split for a quick run.
# Clamp to the split length: Dataset.select(range(1000)) raises IndexError on a split
# with fewer rows. Iterate over a snapshot of the split names while reassigning entries.
# NOTE(review): this keeps the first rows in stored order; if the "all" config is grouped
# by domain, the sample may cover a single domain — consider .shuffle(seed=...) first.
MAX_SAMPLES = 1000
for split in list(dataset.keys()):
    dataset[split] = dataset[split].select(range(min(MAX_SAMPLES, len(dataset[split]))))

dataset

In [None]:
# Tokenizer for model_args' model (training_args presumably supplies options such as a
# chat-template override — confirm in open_r1.utils).
tokenizer = get_tokenizer(model_args, training_args)
tokenizer

In [None]:
if tokenizer.chat_template is None:
    logger.info("No chat template provided, defaulting to ChatML.")
    model, tokenizer = setup_chat_format(model, tokenizer, format="chatml")

tokenizer.chat_template

In [None]:
model = get_model(model_args, training_args)

In [None]:
# Display the model architecture.
model

In [None]:
############################
# Initialize the SFT Trainer
############################

# Cap training at 10 steps for a smoke test (overrides max_steps=-1 from SFTConfig).
training_args.max_steps = 10

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset[script_args.dataset_train_split],
    # eval_strategy is "no" in this config, so no eval split is passed.
    eval_dataset=(dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None),
    processing_class=tokenizer,
    # Presumably None here, since model_args does not enable PEFT.
    peft_config=get_peft_config(model_args),
    callbacks=get_callbacks(training_args, model_args),
)

In [None]:
###############
# Training loop
###############
# An explicit resume_from_checkpoint wins; otherwise fall back to the checkpoint found
# in output_dir (None means train from scratch).
checkpoint = (
    training_args.resume_from_checkpoint
    if training_args.resume_from_checkpoint is not None
    else last_checkpoint
)

checkpoint

In [None]:
# Run SFT, resuming from `checkpoint` when it is not None.
train_result = trainer.train(resume_from_checkpoint=checkpoint)

In [None]:
# `metrics` is the same dict object as train_result.metrics, so train_samples is added in place.
metrics = train_result.metrics
metrics["train_samples"] = len(dataset[script_args.dataset_train_split])
# Print, save to output_dir, and save the trainer state.
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()

In [None]:
# Release the SFT objects before the GRPO section loads a new model.
# IPython's output history (Out, _, __, ___) still references values displayed by
# earlier cells (e.g. the bare `model` cell), so `del` alone would not free them;
# flush that history first when running under IPython.
from IPython import get_ipython

_ipython = get_ipython()
if _ipython is not None:
    _ipython.run_line_magic("reset", "-f out")

# pop() instead of `del` so re-running this cell does not raise NameError.
for _name in ("trainer", "model", "dataset"):
    globals().pop(_name, None)

import gc
gc.collect()

In [None]:
# Return cached, now-unused CUDA memory blocks after the SFT objects were released.
import torch
torch.cuda.empty_cache()

## GRPO

In [None]:
# if IN_COLAB:
#     %cd /content/open-r1
# else:
#     %cd /workspaces/open-r1-distill/open-r1

# !RANK=0 WORLD_SIZE=1 LOCAL_RANK=0 MASTER_ADDR=localhost MASTER_PORT=12345 python src/open_r1/grpo.py \
#     --config recipes/DeepSeek-R1-Distill-Qwen-1.5B/grpo/config_demo.yaml \
#     --model_name_or_path Qwen/Qwen3-0.6B-Base \
#     --hub_model_id OpenR1-Distill-0.6B \
#     --output_dir data/OpenR1-Distill-0.6B \
#     --push_to_hub False \
#     --report_to none \
#     --vllm_mode colocate \
#     --per_device_train_batch_size 1 \
#     --num_generations 4 \
#     --gradient_accumulation_steps 8

In [None]:
import logging
import os
import sys

import datasets
import transformers
from transformers import set_seed
from transformers.trainer_utils import get_last_checkpoint

from open_r1.configs import GRPOConfig, GRPOScriptArguments
from open_r1.rewards import get_reward_funcs
from open_r1.utils import get_dataset, get_model, get_tokenizer
from open_r1.utils.callbacks import get_callbacks
from open_r1.utils.wandb_logging import init_wandb_training
from trl import GRPOTrainer, ModelConfig, TrlParser, get_peft_config

In [None]:
import os

# Single-process "distributed" settings, mirroring the
# `RANK=0 WORLD_SIZE=1 LOCAL_RANK=0 MASTER_ADDR=localhost MASTER_PORT=12345` prefix of
# the commented-out CLI command; set before GRPOConfig / GRPOTrainer are created below.
os.environ.update(
    {
        "RANK": "0",
        "WORLD_SIZE": "1",
        "LOCAL_RANK": "0",
        "MASTER_ADDR": "localhost",
        "MASTER_PORT": "12345",
    }
)

In [None]:
# GRPO data and reward settings. training_args.reward_weights below lists one weight per
# entry of reward_funcs (presumably matched by position — confirm in TRL).
script_args = GRPOScriptArguments(
    dataset_name="open-r1/OpenR1-Math-220k",
    dataset_prompt_column="problem",
    reward_funcs=["accuracy", "format", "tag_count"],
)

In [None]:
# GRPO hyper-parameters adapted from recipes/DeepSeek-R1-Distill-Qwen-1.5B/grpo/config_demo.yaml
# (see the commented CLI above); commented-out lines hold the recipe values.
# NOTE(review): hub_model_id/output_dir mention DeepSeek-R1-Distill-Qwen-0.6B, but model_args
# below loads Qwen/Qwen3-0.6B-Base. max_steps is overridden to 10 before the trainer is built.
training_args = GRPOConfig(
    bf16=True,
    use_vllm=True,
    do_eval=False,
    # gradient_accumulation_steps=4,
    gradient_accumulation_steps=8,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    # hub_model_id="DeepSeek-R1-Distill-Qwen-1.5B-GRPO",
    hub_model_id="DeepSeek-R1-Distill-Qwen-0.6B-GRPO",
    hub_strategy="every_save",
    learning_rate=1.0e-6,
    log_completions=True,
    log_level="info",
    logging_first_step=True,
    logging_steps=1,
    logging_strategy="steps",
    lr_scheduler_type="cosine_with_min_lr",
    lr_scheduler_kwargs={"min_lr_rate": 0.1},
    max_prompt_length=512,
    max_completion_length=2048,
    max_steps=-1,
    # num_generations=16,
    num_generations=4,
    num_train_epochs=1,
    # output_dir="data/DeepSeek-R1-Distill-Qwen-1.5B-GRPO",
    output_dir="data/DeepSeek-R1-Distill-Qwen-0.6B-GRPO",
    overwrite_output_dir=True,
    # per_device_eval_batch_size=16,
    per_device_eval_batch_size=1,
    # per_device_train_batch_size=16,
    per_device_train_batch_size=1,
    # push_to_hub=True,
    push_to_hub=False,
    # report_to=["wandb"],
    report_to=[],
    # reward_funcs=["accuracy", "format", "tag_count"],
    # One weight per reward function in script_args.reward_funcs.
    reward_weights=[1.0, 1.0, 1.0],
    save_strategy="epoch",
    save_total_limit=1,
    seed=42,
    temperature=0.7,
    use_liger_kernel=True,
    warmup_ratio=0.1,
    # Prepended as the system message by make_conversation below.
    system_prompt="You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>",
    # chat_template="",
    vllm_mode="colocate",
)

In [None]:
# Upstream recipe model settings, kept for reference:
# model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
# model_revision: main
# torch_dtype: bfloat16
# attn_implementation: flash_attention_2

# Same settings, with the Qwen3-0.6B base model substituted for this demo.
model_args = ModelConfig(
    # model_name_or_path="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    model_name_or_path="Qwen/Qwen3-0.6B-Base",
    model_revision="main",
    torch_dtype="bfloat16",
    attn_implementation="flash_attention_2",
)

In [None]:
# Seed the random-number generators (Python, NumPy, PyTorch) for reproducibility.
set_seed(training_args.seed)

In [None]:
###############
# Setup logging
###############
# logging.basicConfig(
#     format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
#     datefmt="%Y-%m-%d %H:%M:%S",
#     handlers=[logging.StreamHandler(sys.stdout)],
# )
# log_level = training_args.get_process_log_level()
# logger.setLevel(log_level)
# datasets.utils.logging.set_verbosity(log_level)
# transformers.utils.logging.set_verbosity(log_level)
# transformers.utils.logging.enable_default_handler()
# transformers.utils.logging.enable_explicit_format()

# # Log on each process a small summary
# logger.warning(
#     f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
#     + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
# )
# logger.info(f"Model parameters {model_args}")
# logger.info(f"Script parameters {script_args}")
# logger.info(f"Training parameters {training_args}")

In [None]:
# Look for a checkpoint left in output_dir by an earlier run.
last_checkpoint = (
    get_last_checkpoint(training_args.output_dir)
    if os.path.isdir(training_args.output_dir)
    else None
)
if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
    logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")

# Initialise Weights & Biases only when it is one of the reporting targets.
if "wandb" in training_args.report_to:
    init_wandb_training(training_args)

In [None]:
# Load the dataset described by script_args (open-r1/OpenR1-Math-220k).
dataset = get_dataset(script_args)

In [None]:
# Display the splits and columns.
dataset

In [None]:
# Limit every split to at most 1000 samples for debugging.
# Clamp to the split length: Dataset.select(range(1000)) raises IndexError on a split
# with fewer rows. Iterate over a snapshot of the split names while reassigning entries.
MAX_SAMPLES = 1000
for split in list(dataset.keys()):
    dataset[split] = dataset[split].select(range(min(MAX_SAMPLES, len(dataset[split]))))

dataset

In [None]:
################
# Load tokenizer
################
# Tokenizer for model_args' model; displayed as the cell output.
tokenizer = get_tokenizer(model_args, training_args)
tokenizer

In [None]:
# Load the policy model described by model_args and display its architecture.
model = get_model(model_args, training_args)
model

In [None]:
# Get reward functions from the registry: resolves the names in script_args.reward_funcs
# ("accuracy", "format", "tag_count") to callables.
reward_funcs = get_reward_funcs(script_args)
reward_funcs

In [None]:
# Format into conversation
def make_conversation(example, prompt_column: str = script_args.dataset_prompt_column):
    """Build the chat-format ``prompt`` for one dataset row.

    Args:
        example: A dataset row (mapping from column name to value).
        prompt_column: Column holding the user question. The default is
            ``script_args.dataset_prompt_column``, evaluated once when the
            function is defined.

    Returns:
        ``{"prompt": [...]}``: an optional system message taken from
        ``training_args.system_prompt``, followed by the user message.

    Raises:
        ValueError: If ``prompt_column`` is missing from ``example``.
    """
    system_prompt = training_args.system_prompt
    messages = [] if system_prompt is None else [{"role": "system", "content": system_prompt}]

    if prompt_column not in example:
        raise ValueError(f"Dataset Question Field Error: {prompt_column} is not supported.")

    messages.append({"role": "user", "content": example[prompt_column]})
    return {"prompt": messages}

# Add the conversational "prompt" column to every split.
dataset = dataset.map(make_conversation)
dataset

In [None]:
# Drop the dataset's original "messages" column wherever present; GRPO reads the
# "prompt" column built by make_conversation instead.
for split_name in list(dataset):
    if "messages" not in dataset[split_name].column_names:
        continue
    dataset[split_name] = dataset[split_name].remove_columns("messages")

dataset

In [None]:
#############################
# Initialize the GRPO trainer
#############################
# Cap training at 10 steps for a smoke test (overrides max_steps=-1 from GRPOConfig).
training_args.max_steps = 10

trainer = GRPOTrainer(
    model=model,
    reward_funcs=reward_funcs,
    args=training_args,
    train_dataset=dataset[script_args.dataset_train_split],
    # eval_strategy is left at its default here; an eval split is passed only if it is not "no".
    eval_dataset=(dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None),
    # Presumably None here, since model_args does not enable PEFT.
    peft_config=get_peft_config(model_args),
    callbacks=get_callbacks(training_args, model_args),
    processing_class=tokenizer,
)

In [None]:
# An explicit resume_from_checkpoint wins; otherwise fall back to the checkpoint found
# in output_dir (None means train from scratch).
checkpoint = (
    training_args.resume_from_checkpoint
    if training_args.resume_from_checkpoint is not None
    else last_checkpoint
)

In [None]:
# Run GRPO, resuming from `checkpoint` when it is not None.
train_result = trainer.train(resume_from_checkpoint=checkpoint)
# `metrics` is the same dict object as train_result.metrics, so train_samples is added in place.
metrics = train_result.metrics
metrics["train_samples"] = len(dataset[script_args.dataset_train_split])
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()

In [None]:
##################################
# Save model and create model card
##################################
# Align the model's generation config with the tokenizer's eos token
# to avoid unbounded generation in the transformers `pipeline()` function
trainer.model.generation_config.eos_token_id = tokenizer.eos_token_id
trainer.save_model(training_args.output_dir)

# Save everything else on main process
# (kwargs is reused by the push-to-hub cell below.)
kwargs = {
    "dataset_name": script_args.dataset_name,
    "tags": ["open-r1"],
}
if trainer.accelerator.is_main_process:
    trainer.create_model_card(**kwargs)
    # Restore k,v cache for fast inference
    trainer.model.config.use_cache = True
    trainer.model.config.save_pretrained(training_args.output_dir)

In [None]:
##########
# Evaluate
##########
# Skipped in this notebook: training_args sets do_eval=False.
if training_args.do_eval:
    metrics = trainer.evaluate()
    metrics["eval_samples"] = len(dataset[script_args.dataset_test_split])
    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)

In [None]:
#############
# push to hub
#############
# Skipped in this notebook: training_args sets push_to_hub=False.
if training_args.push_to_hub:
    logger.info("Pushing to hub...")
    trainer.push_to_hub(**kwargs)
