In [None]:
import os
from pathlib import Path

# Set the environment variables for HuggingFace
# This is done to ensure that the cache directory for HuggingFace is set to a specific location,
# preventing the storage from being overwhelmed with model files and other data.
SCRATCH = Path.home() / "scratch"
os.environ["HF_HOME"] = str(SCRATCH / "hf_home")

In [None]:
import os
import sys

# Directory added to the import path for the project-local modules used below
# (presumably where prompt_utils / reward_functions / utils live — verify).
# Set the NANO_AHA_MOMENT_DIR environment variable to point at a checkout in a
# different location; the default keeps the original hard-coded path.
NANO_AHA_MOMENT_DIR = os.environ.get(
    "NANO_AHA_MOMENT_DIR", "/home/htkumar/torchtune/deep_rl/nano_aha_moment"
)

# Only append once so that re-running this cell does not keep growing sys.path.
if NANO_AHA_MOMENT_DIR not in sys.path:
    sys.path.append(NANO_AHA_MOMENT_DIR)

In [None]:
import gc
import re
import time
from typing import Any, Dict, List, Tuple, Union

import deepspeed
import numpy as np
import torch
from datasets import load_dataset
from deepspeed import DeepSpeedEngine
from tqdm import trange
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
from vllm import LLM, SamplingParams

# TODO: Add deepspeed params if needed

In [None]:
# Hyperparameters
# Base model used for the policy, the reference model and the vLLM engine.
MODEL_NAME = "Qwen/Qwen2.5-3B"
# Instruct variant; only its tokenizer is loaded (for `apply_chat_template`).
MODEL_CHAT_NAME = MODEL_NAME + "-Instruct"

# Dataset configuration
DATASET_NAME = "Jiayi-Pan/Countdown-Tasks-3to4"

NUM_ITERATIONS = 1000
EPISODES_PER_ITERATION = 64
# Number of completions sampled for each prompt; generations are grouped in
# chunks of this size when rewards are computed.
GENERATIONS_PER_SAMPLE = 4
# Weight of the KL penalty term added to the policy loss in compute_pg_loss.
KL_COEFFICIENT = 0.001

# Micro-batch size per device. The effective batch is 64 episodes, so the
# remaining factor is covered by gradient accumulation.
PER_DEVICE_BATCH_SIZE = 4
LEARNING_RATE = 1e-6

# Sampling params
MAX_RESPONSE_TOKENS = 1024
# Also used to scale logits in compute_pg_loss -> compute_token_log_probs.
TEMPERATURE = 1.0
TOP_P = 1.0 # 1.0 disables nucleus (top-p) sampling
TOP_K = -1 # -1 disables top-k filtering

# TODO: define deepspeed configs here if needed.

In [None]:
RUN_NAME = "r1-zero"
EXP_DIR = SCRATCH / "deepseek_r1_replica" / RUN_NAME
EXP_DIR.mkdir(parents=True, exist_ok=True)
EXP_DIR

In [None]:
from prompt_utils import (
    DEFAULT_SYSTEM_MESSAGE,
    DEFAULT_PROMPT_TEMPLATE
)

In [None]:
# We use the chat model tokenizer so that we can use `apply_chat_template` to the prompt
tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(MODEL_CHAT_NAME)
# The EOS id is deliberately read from the *base* model's tokenizer (MODEL_NAME),
# not from the chat tokenizer loaded above.
# NOTE(review): presumably because the base and Instruct tokenizers declare
# different EOS tokens and the base model is the one that generates — confirm.
EOS_TOKEN_ID = AutoTokenizer.from_pretrained(MODEL_NAME).eos_token_id
# String form of the EOS id; format_reward_func strips this suffix from completions.
EOS_TOKEN = tokenizer.convert_ids_to_tokens(EOS_TOKEN_ID)
EOS_TOKEN_ID, EOS_TOKEN

In [None]:
# tokenizer

In [None]:
def preprocess_countdown_example(example: Dict[str, Any]):
    """
    Build the tokenized chat prompt for a single Countdown example.

    Args:
        example: A dataset row; its "nums" and "target" entries are substituted
            into DEFAULT_PROMPT_TEMPLATE.

    Returns:
        A dict with:
            "input_ids": token ids produced by the chat template.
            "prompt": those ids decoded back to text, special tokens included.
    """
    user_message = DEFAULT_PROMPT_TEMPLATE.format(
        numbers=example["nums"], target=example["target"]
    )

    # The assistant turn is pre-filled and left open (continue_final_message=True)
    # so that generation continues right after the opening <think> tag.
    conversation = [
        {"role": "system", "content": DEFAULT_SYSTEM_MESSAGE},
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": "Let me think step by step\n<think>"},
    ]
    token_ids = tokenizer.apply_chat_template(
        conversation, tokenize=True, continue_final_message=True
    )

    rendered_prompt = tokenizer.decode(
        token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
    )
    return {"input_ids": token_ids, "prompt": rendered_prompt}

In [None]:
dataset = load_dataset(DATASET_NAME, split='train')
dataset = dataset.map(preprocess_countdown_example, num_proc=8)

In [None]:
len(dataset)

In [None]:
dataset[0]['prompt']

In [None]:
train_test_split = dataset.train_test_split(test_size=500, seed=42)
train_dataset = train_test_split['train']
test_dataset = train_test_split['test']
len(train_dataset), len(test_dataset)

In [None]:
train_dataset[0]['nums']

In [None]:
train_dataset[0]['target']

In [None]:
# Hard-coded override of the EOS_TOKEN that an earlier cell derives from
# EOS_TOKEN_ID; format_reward_func reads this global to strip a trailing EOS.
# NOTE(review): confirm this literal equals
# tokenizer.convert_ids_to_tokens(EOS_TOKEN_ID), otherwise the two cells disagree
# depending on execution order.
EOS_TOKEN = "<|endoftext|>"

In [None]:
def format_reward_func(completion: str, eos_token: Union[str, None] = None) -> float:
    r"""
    Score how closely a completion follows the expected output format.

    Expected format: <think>...</think>\n<answer>...</answer>

    The opening ``<think>`` tag is prepended before matching, so ``completion``
    is expected to start *after* that tag. A single trailing EOS token is
    stripped before matching.

    Args:
        completion: Generated text (without the opening ``<think>``).
        eos_token: EOS string to strip from the end of ``completion``. When
            ``None`` (the default) the notebook-level ``EOS_TOKEN`` is used,
            which preserves the original call signature.

    Returns:
        0.0 if the text does not match the format (or ``completion`` is not a
        string), 0.5 if it matches but the answer contains characters other
        than digits, ``+ - * / ( ) .`` and whitespace, and 1.0 otherwise.
    """
    allowed_pattern = r"^[\d+\-*/().\s]+$"

    # Explicit guard instead of a blanket try/except: a non-string completion
    # simply earns no format reward.
    if not isinstance(completion, str):
        return 0.0

    if eos_token is None:
        eos_token = EOS_TOKEN

    completion = "<think>" + completion

    # `eos_token` must be non-empty: with "" the slice `[:-0]` would evaluate to
    # `[:0]` and wipe the whole completion.
    if eos_token and completion.endswith(eos_token):
        completion = completion[: -len(eos_token)]

    # Group 1: the thinking section, which may contain "<" but no nested
    # <think> / </think> tag. Group 2: the answer body.
    regex = r"^<think>([^<]*(?:<(?!/?think>)[^<]*)*)<\/think>\n<answer>([\s\S]*?)<\/answer>$"
    match = re.search(regex, completion, re.DOTALL)
    if match is None:
        return 0.0

    answer_content = match.group(2).strip()
    if not re.match(allowed_pattern, answer_content):
        return 0.5
    return 1.0


In [None]:
EOS_TOKEN

In [None]:
format_reward_func(
    """Using the numbers [4, 3, 56, 41], create an equation that equals 97.</think>
First, I'll add 41 to 48 (56 + 4 - 6) to get 48. Then, I'll subtract 3 (since it's leftover) to get 45. Now, I have 48 and 45, which add up to 93. So, I need another 7 to get 97. I know that 7 is 14/2, so I'll multiply 48 by 14 which equals 672, then divide it by 2 (i.e. 672 / 2). Finally, I'll subtract it from 97 to achieve that difference of 7. Therefore, the final equation is <answer>96 - (48 * (672 / 2)) = 97</answer>.â‹…<|endoftext|>"""
)

In [None]:
format_reward_func("I am thinking </think>\n<answer>abcd</answer>")

In [None]:
format_reward_func("I am thinking </think>\n<answer>1+2</answer>")

In [None]:
format_reward_func("I am <thinking> </think>\n<answer>1+2</answer>")

In [None]:
format_reward_func("I am <think> </think>\n<answer>1+2</answer>")

In [None]:
def equation_reward_func(completion: str, nums: List[int], target: int) -> float:
    """
    Reward a completion whose <answer> equation uses the given numbers to hit the target.

    Args:
        completion: Generated text containing ``<answer>...</answer>`` on a single line.
        nums: Numbers that must each appear exactly once in the equation.
        target: Value the equation has to evaluate to.

    Returns:
        1.0 if the equation contains only digits, ``+ - * / ( ) .`` and
        whitespace, uses exactly the numbers in ``nums`` and evaluates to
        ``target`` (within 1e-6); 0.0 in every other case, including when the
        expression cannot be evaluated.
    """
    try:
        match = re.search(r"<answer>(.*?)<\/answer>", completion)
        if match is None:
            return 0.0

        equation = match.group(1).strip()

        # Character whitelist: this is what makes the eval() below tolerable on
        # model-generated text.
        allowed_pattern = r"^[\d+\-*/().\s]+$"
        if not re.match(allowed_pattern, equation):
            return 0.0

        # The whitelist also admits "**". Exponentiation is not built from the
        # four listed operators and towers such as 99**98**97**96 can hang the
        # process or exhaust memory inside eval(), so reject it up front.
        # NOTE(review): "//" (floor division) passes the whitelist as well —
        # confirm whether it should be accepted.
        if "**" in equation:
            return 0.0

        used_numbers = [int(n) for n in re.findall(r"\d+", equation)]
        if sorted(used_numbers) != sorted(nums):
            return 0.0

        # SECURITY: eval() on generated text. Kept because the input has been
        # restricted to arithmetic characters above and builtins are removed.
        result = eval(equation, {"__builtins__": None}, {})
        if abs(float(result) - float(target)) < 1e-6:
            return 1.0
        return 0.0

    except Exception:
        # Deliberate best-effort: malformed expressions (SyntaxError), division
        # by zero, non-string completions etc. simply earn no reward.
        return 0.0

In [None]:
def compute_reward(completion: str, sample: Dict[str, Any]) -> Tuple[float, Dict[str, float]]:
    """
    Combine the format and equation rewards for one completion.

    Args:
        completion: Generated text to score.
        sample: Dataset row; its "nums" and "target" entries are passed to
            equation_reward_func.

    Returns:
        (reward, metrics), where reward is the sum of both components with
        weight 1.0 each and metrics holds the individual components.
    """
    metrics = {
        'format_reward': format_reward_func(completion),
        "equation_reward": equation_reward_func(
            completion, sample['nums'], sample['target']
        ),
    }
    # todo: make this weighted?
    total = 1.0 * metrics['format_reward'] + 1.0 * metrics["equation_reward"]
    return total, metrics

In [None]:
# NOTE(review): star import placed after the in-notebook definitions of
# format_reward_func / equation_reward_func / compute_reward. If reward_functions
# exports any of those names, the versions defined above are silently replaced
# from this cell onwards — confirm which implementation is meant to be used.
from reward_functions import *

In [None]:
equation_reward_func("I am thinking </think><answer>1+2</answer>", [1, 2], 3)

In [None]:
equation_reward_func("I am thinking </think><answer>1+2+2</answer>", [1, 2], 3)

In [None]:
equation_reward_func("I am thinking </think><answer>1+4</answer>", [1, 2], 3)

In [None]:
# Toy fixture: one sample and its generations, used to walk through the grouping
# and reward computation in the cells below.
samples = [{"input_ids": [1,2,3], "nums": [1,2,3], "target": 6}]
all_generations = [[4,5, EOS_TOKEN_ID], [6,7], [8,9, EOS_TOKEN_ID], [10, 11]]  # 4 generations for the single sample (grouped by GENERATIONS_PER_SAMPLE)

In [None]:
groups = [
        list(range(i, i + GENERATIONS_PER_SAMPLE))
        for i in range(0, len(all_generations), GENERATIONS_PER_SAMPLE)
    ]

In [None]:
groups

In [None]:
all_query_token_ids = [
        [sample["input_ids"]] * GENERATIONS_PER_SAMPLE for sample in samples
    ]

In [None]:
generation_strs = tokenizer.batch_decode(
        all_generations, skip_special_tokens=False, clean_up_tokenization_spaces=False
    )

In [None]:
generation_strs

In [None]:
generations_str_grouped = [[generation_strs[i] for i in group] for group in groups]

In [None]:
generations_str_grouped

In [None]:
rewards = [
        [compute_reward(generation_str, sample) for generation_str in generations]
        for sample, generations in zip(samples, generations_str_grouped)
    ]

In [None]:
rewards = [
        [compute_reward(generation_str, sample)[0] for generation_str in generations]
        for sample, generations in zip(samples, generations_str_grouped)
    ]

In [None]:
rewards

In [None]:
rewards = np.array(rewards)

In [None]:
rewards.mean(), rewards.std()

In [None]:
a = np.array(rewards)

In [None]:
a.mean(), a.std()

In [None]:
arr = [
    [1, 2, 3, 4],
    [5, 6, 7, 8]
]

In [None]:
b = np.array(arr)

In [None]:
from utils import create_training_episodes

In [None]:
case_0 = {
    "samples": [{"input_ids": [1,2,3], "nums": [1,2,3], "target": 6}],
    "all_generations": [[4,5, 22, 33], [6,7], [8,9, 11], [10,11]],
    "all_finish_reasons": ["stop", "length", "stop", "stop"]
}
create_training_episodes(tokenizer, **case_0)

In [None]:
case_1 = {
    "samples": [{"input_ids": [1,2,3], "nums": [1,2,3], "target": 6}, {"input_ids": [9, 8, 7, 6, 5, 4], "nums": [1,2,3,4], "target": 10}],
    "all_generations": [[4,5, 22, 33], [6,7], [8,9, 11], [10,11], [9,10], [11,12], [13,14], [15,16]],
    "all_finish_reasons": ["stop", "length", "stop", "stop", "length", "length", "stop", "stop"]
}

In [None]:
episodes, stats = create_training_episodes(tokenizer, **case_1)

In [None]:
print(create_training_episodes(tokenizer, **case_1)[0]['all_advantages'])

In [None]:
from utils import prepare_model_inputs

In [None]:
episodes

In [None]:
def prepare_model_inputs1(
    training_episodes: Dict[str, Any],
    device: torch.device,
    pad_token_id: int = 0,
    ignore_index: int = -100,
) -> Dict[str, torch.Tensor]:
    """
    Right-pad query+response episodes into dense tensors for a forward pass.

    Args:
        training_episodes: Dict with parallel lists under
            "all_query_token_ids", "all_response_token_ids" and
            "all_advantages" (one entry per episode; the advantages of an
            episode are per response token).
        device: Device on which the tensors are created.
        pad_token_id: Token id used to pad ``input_ids``.
        ignore_index: Label value for query and padding positions
            (check nn.CrossEntropyLoss for more context).

    Returns:
        Dict of tensors, each of shape [num_episodes, max_seq_len]:
            "input_ids": query + response + padding (long).
            "attention_mask": 1 for real tokens, 0 for padding (long).
            "labels": response tokens, ``ignore_index`` elsewhere (long).
            "advantages": response advantages, 0.0 elsewhere (float).

    Raises:
        ValueError: If the three lists differ in length or are empty.
    """
    query_token_ids = training_episodes["all_query_token_ids"]
    response_token_ids = training_episodes["all_response_token_ids"]
    advantages = training_episodes["all_advantages"]

    # zip() would silently drop trailing episodes on a mismatch, so fail loudly.
    if not (len(query_token_ids) == len(response_token_ids) == len(advantages)):
        raise ValueError(
            "Episode lists differ in length: "
            f"{len(query_token_ids)} queries, {len(response_token_ids)} responses, "
            f"{len(advantages)} advantages"
        )
    if len(query_token_ids) == 0:
        raise ValueError("training_episodes contains no episodes")

    max_seq_len = max(
        len(q) + len(r) for q, r in zip(query_token_ids, response_token_ids)
    )
    input_ids, attention_mask, labels, advantage_list = [], [], [], []

    for q, r, a in zip(query_token_ids, response_token_ids, advantages):
        combined_ids = q + r
        seq_len = len(combined_ids)
        pad_len = max_seq_len - seq_len

        input_ids.append(combined_ids + [pad_token_id] * pad_len)
        attention_mask.append([1] * seq_len + [0] * pad_len)
        # Only response positions carry a label / advantage.
        labels.append([ignore_index] * len(q) + r + [ignore_index] * pad_len)
        advantage_list.append([0.0] * len(q) + a + [0.0] * pad_len)

    return {
        "input_ids": torch.tensor(input_ids, dtype=torch.long, device=device),
        "attention_mask": torch.tensor(attention_mask, dtype=torch.long, device=device),
        "labels": torch.tensor(labels, dtype=torch.long, device=device),
        "advantages": torch.tensor(advantage_list, dtype=torch.float, device=device),
    }

In [None]:
prepare_model_inputs(
    episodes, "cuda"
)

In [None]:
logits = torch.ones(4, 4, 8)
logits.shape

In [None]:
torch.softmax(logits, dim=-1).shape

In [None]:
torch.log_softmax(logits, dim=-1).shape

In [None]:
torch.log(torch.softmax(logits, dim=-1))[0, 0, 0]

In [None]:
a = torch.tensor([0.1, 0.5, 0.9])
torch.log(a)

In [None]:
def compute_token_log_probs(
    model: "PreTrainedModel",
    inputs: Dict[str, torch.Tensor],
    temperature: float,
) -> torch.Tensor:
    """
    Compute the log probabilities of the next token given the input sequence.

    Args:
        model: Causal LM called with ``input_ids`` / ``attention_mask``; its
            output must expose ``.logits``.
        inputs: Dict with "input_ids", "attention_mask" and "labels". Label
            positions equal to -100 are ignored. ``inputs`` is not modified.
        temperature: Logits are divided by this value before the softmax.

    Returns:
        Tensor of shape [bsz, seq_len-1]: entry ``t`` is the log-probability of
        ``labels[:, t+1]`` predicted from position ``t``, and 0 where that
        label is -100.
    """
    model_output = model(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        return_dict=True,
        use_cache=False,
    )
    logits = model_output.logits.float() / temperature
    shift_logits = logits[..., :-1, :].contiguous()  # [bsz, seq_len-1, vocab_size]
    shift_labels = inputs["labels"][..., 1:]  # [bsz, seq_len-1]

    label_mask = shift_labels != -100
    # Replace ignored labels with a valid index for gather(). This must be
    # out-of-place: `labels[..., 1:].contiguous()` is NOT guaranteed to be a copy
    # (for batch size 1 the slice is already contiguous), so an in-place write
    # would erase the -100 markers in the caller's tensor and break the mask on
    # any later call with the same batch.
    gather_index = shift_labels.masked_fill(~label_mask, 0)

    # [bsz, seq_len-1, vocab_size]
    log_probs = torch.log_softmax(shift_logits, dim=-1)  # log(softmax(logits))

    # [bsz, seq_len-1]
    log_probs = torch.gather(
        log_probs, dim=-1, index=gather_index.unsqueeze(-1)
    ).squeeze(-1)
    # [bsz, seq_len-1]; zero out positions without a label
    log_probs = log_probs * label_mask.float()
    return log_probs

In [None]:
def compute_pg_loss(
    policy_model: PreTrainedModel,
    reference_model: PreTrainedModel,
    input_batch: Dict[str, torch.tensor],
    total_response_len: int,
) -> Tuple[torch.tensor, Dict[str, float]]:
    """
    Policy-gradient loss with a KL penalty against the reference model.

    Per token the objective is ``-logp_policy * advantage`` plus
    ``KL_COEFFICIENT * (exp(d) - d - 1)`` with ``d = logp_ref - logp_policy``.
    The per-token values are summed and divided by ``total_response_len``.

    Args:
        policy_model: Model being trained; gradients flow through its log-probs.
        reference_model: Model evaluated under ``torch.no_grad()``.
        input_batch: Batch passed to compute_token_log_probs; must also contain
            "advantages" aligned with the unshifted sequence ([bsz, seq_len]).
        total_response_len: Normaliser for the loss and for every metric.

    Returns:
        (loss, metrics): the scalar loss and a dict of Python floats.
    """
    # All log-prob tensors below are [bsz, seq_len-1].
    with torch.no_grad():
        ref_logps = compute_token_log_probs(reference_model, input_batch, TEMPERATURE)
    policy_logps = compute_token_log_probs(policy_model, input_batch, TEMPERATURE)

    log_ratio = ref_logps - policy_logps
    kl_term = torch.exp(log_ratio) - log_ratio - 1

    # Advantages are shifted by one to line up with the next-token log-probs.
    shifted_advantages = input_batch["advantages"][..., 1:]
    pg_term = -policy_logps * shifted_advantages  # [bsz, seq_len-1]

    per_token_objective = pg_term + KL_COEFFICIENT * kl_term
    loss = per_token_objective.sum() / total_response_len  # scalar

    metrics = {}
    metrics["policy_loss"] = pg_term.sum().item() / total_response_len
    metrics["kl_distance"] = kl_term.sum().item() / total_response_len
    # Mean negative log-prob of the sampled tokens, reported as an entropy proxy:
    # expected to fall for the policy as it grows more certain and to stay flat
    # for the reference model.
    metrics["entropy_policy"] = -policy_logps.sum().item() / total_response_len
    metrics["entropy_ref"] = -ref_logps.sum().item() / total_response_len
    return loss, metrics

### Training code

In [None]:
# Two separate copies of the same base checkpoint, both placed on device 0 in
# bfloat16 with FlashAttention-2:
#   - policy_model: the copy whose log-probs carry gradients in compute_pg_loss.
#   - reference_model: only evaluated under torch.no_grad() in compute_pg_loss,
#     where it supplies the KL penalty baseline.
policy_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    device_map=0,
)
reference_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    device_map=0,
)

In [None]:
policy_model

In [None]:
policy_model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

In [None]:
# vLLM engine built from the same base checkpoint as the policy.
# NOTE(review): policy_model and reference_model are already resident on
# device 0 and no gpu_memory_utilization is passed here, so vLLM's default
# applies — confirm the engine fits next to them on that GPU.
# NOTE(review): max_model_len=2048 presumably has to cover prompt length plus
# MAX_RESPONSE_TOKENS (1024) — verify against the longest prompt.
inference_engine = LLM(
    model=MODEL_NAME,
    skip_tokenizer_init=False,
    enable_prefix_caching=True,
    swap_space=1,
    scheduling_policy='fcfs',
    dtype=torch.bfloat16,
    max_model_len=2048,
    enable_sleep_mode=True,
)