In [None]:
import os
from pathlib import Path

# Point the HuggingFace cache at a dedicated scratch location.
# Keeping model files and other downloaded data there prevents the default
# storage from being overwhelmed.
SCRATCH = Path.home().joinpath("scratch")
os.environ["HF_HOME"] = str(SCRATCH.joinpath("hf_home"))

In [None]:
import os
import sys

# Directory added to the import path for local helper modules
# (presumably where `prompt_utils` lives — confirm).
# Set the NANO_AHA_MOMENT_DIR environment variable to use a checkout in a
# different location; the default keeps the original hard-coded path.
NANO_AHA_MOMENT_DIR = os.environ.get(
    "NANO_AHA_MOMENT_DIR", "/home/htkumar/torchtune/deep_rl/nano_aha_moment"
)

# Guarded so that re-running this cell does not pile up duplicate entries.
if NANO_AHA_MOMENT_DIR not in sys.path:
    sys.path.append(NANO_AHA_MOMENT_DIR)

In [None]:
import gc
import re
import time
from typing import Any, Dict, List, Tuple, Union

import deepspeed
import numpy as np
import torch
from datasets import load_dataset
from deepspeed import DeepSpeedEngine
from tqdm import trange
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
from vllm import LLM, SamplingParams

# TODO: Add deepspeed params if needed

In [None]:
# Hyperparameters
# Base model and the name of its chat (`-Instruct`) variant.
MODEL_NAME = "Qwen/Qwen2.5-3B"
MODEL_CHAT_NAME = f"{MODEL_NAME}-Instruct"

# Dataset configuration
DATASET_NAME = "Jiayi-Pan/Countdown-Tasks-3to4"

# Training loop
NUM_ITERATIONS = 1_000
EPISODES_PER_ITERATION = 64
GENERATIONS_PER_SAMPLE = 4
KL_COEFFICIENT = 1e-3

# Micro-batch size per device: the actual batch size is 64, reached through
# gradient accumulation.
PER_DEVICE_BATCH_SIZE = 4
LEARNING_RATE = 1e-6

# Sampling params
MAX_RESPONSE_TOKENS = 1_024
TEMPERATURE = 1.0
TOP_P = 1.0  # nucleus sampling disabled
TOP_K = -1  # no top_k

# TODO: define deepspeed configs here if needed.

In [None]:
# Every run gets its own experiment directory under the scratch space.
RUN_NAME = "r1-zero"
EXP_DIR = SCRATCH.joinpath("deepseek_r1_replica", RUN_NAME)
# Create the directory (and any missing parents); a no-op if it already exists.
EXP_DIR.mkdir(exist_ok=True, parents=True)
EXP_DIR

In [None]:
from prompt_utils import (
    SYSTEM_MESSAGE,
    PROMPT_TEMPLATE
)

In [None]:
# We use the chat model tokenizer so that we can use `apply_chat_template` to the prompt
tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(MODEL_CHAT_NAME)
# The EOS id is read from the base model's tokenizer (MODEL_NAME), not the chat one.
# NOTE(review): presumably the base model ends its generations with a different
# EOS token than the chat variant — confirm.
EOS_TOKEN_ID = AutoTokenizer.from_pretrained(MODEL_NAME).eos_token_id
# Token string for that id, looked up through the chat tokenizer.
# NOTE(review): assumes both tokenizers share the same vocabulary — confirm.
EOS_TOKEN = tokenizer.convert_ids_to_tokens(EOS_TOKEN_ID)
# Show both values as the cell output.
EOS_TOKEN_ID, EOS_TOKEN

In [None]:
def preprocess_countdown_example(example: Dict[str, Any]):
    """Build the tokenized chat prompt for a single Countdown example.

    The example's ``nums`` and ``target`` fields are substituted into
    ``PROMPT_TEMPLATE`` to form the user message. That message is placed in a
    conversation after ``SYSTEM_MESSAGE`` and before a pre-filled assistant
    turn, and the conversation is run through the module-level ``tokenizer``.

    Args:
        example: A dataset row providing the keys ``"nums"`` and ``"target"``.

    Returns:
        A dict with two keys:
        ``"input_ids"`` - the output of ``tokenizer.apply_chat_template``;
        ``"prompt"`` - those ids decoded back to text, special tokens kept.
    """
    user_message = PROMPT_TEMPLATE.format(
        numbers=example["nums"], target=example["target"]
    )

    conversation = [
        dict(role="system", content=SYSTEM_MESSAGE),
        dict(role="user", content=user_message),
        # The assistant turn is pre-filled with the start of the answer.
        dict(role="assistant", content="Let me think step by step\n<think>"),
    ]

    token_ids = tokenizer.apply_chat_template(
        conversation, tokenize=True, continue_final_message=True
    )
    # Decode without any cleanup so the text mirrors the token ids.
    rendered_prompt = tokenizer.decode(
        token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
    )

    return {"input_ids": token_ids, "prompt": rendered_prompt}

In [None]:
# Load the train split and attach `input_ids` / `prompt` to every example,
# using 8 worker processes for the mapping step.
dataset = load_dataset(DATASET_NAME, split="train").map(
    preprocess_countdown_example, num_proc=8
)

In [None]:
# Number of examples in the preprocessed dataset.
len(dataset)

In [None]:
# Inspect the decoded prompt text produced for the first example.
dataset[0]['prompt']

In [None]:
# Hold out 500 examples for evaluation; the fixed seed keeps the split reproducible.
train_test_split = dataset.train_test_split(test_size=500, seed=42)
train_dataset, test_dataset = (
    train_test_split[split_name] for split_name in ("train", "test")
)
# Show the size of each split as the cell output.
len(train_dataset), len(test_dataset)

In [None]:
# `nums` field of the first training example.
train_dataset[0]['nums']

In [None]:
# `target` field of the first training example.
train_dataset[0]['target']