In [None]:
!pip install -U 'bitsandbytes<0.46.0' --quiet
!pip install -U flash-attn --no-build-isolation --quiet
!pip install 'git+https://github.com/nmecklenburg/transformers.git@nmeck/proj-changes#egg=transformers[torch]' --quiet
!pip install -U 'vllm<0.9.0' --quiet
# !pip install -U 'trl<0.18.0' --quiet
!pip install 'git+https://github.com/nmecklenburg/trl.git@nmeck/proj-changes#egg=trl' --quiet
!pip install -U 'tensorboard<2.19' --quiet

In [None]:
from datasets import concatenate_datasets, load_dataset, Dataset
from itertools import islice
from peft import LoraConfig, PeftModel, get_peft_model
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from trl import (
    GRPOConfig,
    GRPOTrainer,
    SFTConfig,
    SFTTrainer,
    get_peft_config,
)

import bitsandbytes
import re
import shutil
import torch

In [None]:
from google.colab import drive
drive.mount('/content/drive')

# Countdown Dataset Synthesis

In [None]:
import operator
import random

from sympy import sympify

# Smoke test of sympy parsing: with standard precedence this is 1 + 28 + 2 = 31.
result = sympify("1 + 4 * 7 + 4 / 2")
print(result)

In [None]:
# Largest operand the generators draw; used as the inclusive upper bound of random.randint.
MAX_NUMBER = 100

def get_factors(n):
  """Return divisors of `n` found by trial division, in arbitrary order.

  Candidates run from int(sqrt(n)) + 1 down to 2. Each candidate that divides
  `n` contributes itself and its cofactor, so 1 and `n` are normally excluded
  (n == 2 is an exception and yields 1 and 2). When no candidate divides `n`
  the result is an empty list.
  """
  factors = []
  for candidate in range(int(n ** 0.5) + 1, 1, -1):
    quotient, remainder = divmod(n, candidate)
    if remainder == 0:
      factors += [candidate, quotient]
  # Deduplicate (perfect squares add the root twice).
  return list(set(factors))

def is_prime(n):
  """Return True when `n` has no divisor in 2..int(sqrt(n)).

  Every n <= 3 is reported as prime, which includes 1 (and 0 or negative
  values), so this is not a strict primality test.
  """
  if n <= 3:
    return True
  # Composite as soon as any candidate divides n evenly.
  return all(n % candidate != 0 for candidate in range(2, int(n ** 0.5) + 1))

def get_random_divisor(n):
  """Pick one element of `get_factors(n)` uniformly at random.

  random.choice raises IndexError when `get_factors(n)` is empty.
  """
  return random.choice(get_factors(n))

def get_random_number(op=None, last_num=None):
  """Draw the next operand for a formula.

  For op == "/" the operand is a random divisor of `last_num` (which must be
  given); for any other operator it is a random integer in [1, MAX_NUMBER].
  """
  if op != "/":
    return random.randint(1, MAX_NUMBER)
  assert last_num is not None
  return get_random_divisor(last_num)

def get_composite_number():
  """Draw integers in [1, MAX_NUMBER] until one is rejected by `is_prime`, and return it."""
  while True:
    candidate = random.randint(1, MAX_NUMBER)
    if not is_prime(candidate):
      return candidate

# Maps each operator symbol to the two-argument function that implements it.
# "/" is mapped to floor division.
sym_to_op = {
    "+": operator.add,
    "-": operator.sub,  # was operator.call, which calls its first argument instead of subtracting
    "*": operator.mul,
    "/": operator.floordiv
}

In [None]:
# Build `bank`: random left-to-right formulas with 2-3 operators whose value is
# a positive integer below 2000.
NUM_EXAMPLES_TO_GENERATE = 10000
MIN_OPERATORS, MAX_OPERATORS = 2, 3
VALID_OPERATORS = ["+", "-", "*", "/"]

bank = [] # (result, shuffled operands, equation string)

while len(bank) < NUM_EXAMPLES_TO_GENERATE:
  formula = [get_composite_number()]
  num_operators = random.randint(MIN_OPERATORS, MAX_OPERATORS)

  for _ in range(num_operators):
    op = random.choice(VALID_OPERATORS)
    if is_prime(formula[-1]) and op == "/":
      # Only divide composite numbers.
      break
    else:
      # For "/" the new operand is a divisor of the previous operand only,
      # not of the running value of the expression.
      num = get_random_number(op=op, last_num=formula[-1])
    formula.extend([op, num])

  # A formula cut short by the break above has fewer tokens than requested; discard it.
  if len(formula) < (2 * num_operators + 1):
    continue

  formula_expr = ' '.join([str(f) for f in formula])
  result = sympify(formula_expr)
  # str(result).isnumeric() only accepts results that print as digits alone
  # (so anything printed with a sign, "/" or "." is dropped).
  if result > 0 and result < 2000 and str(result).isnumeric():
    # Sorting on a random key shuffles the operands so their order does not reveal the formula.
    bank.append((result, sorted([f for f in formula if isinstance(f, int)], key=lambda _: random.random()), formula_expr))

In [None]:
bank[0]

# Datasets

In [None]:
PRETRAINED_MODEL_NAME = "Qwen/Qwen2-0.5B-Instruct"
# PRETRAINED_MODEL_NAME = "Microsoft/Phi-3.5-mini-instruct"
# PRETRAINED_MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"

In [None]:
# ========= OLD FORMAT =========
# countdown_prompt_template = (
#     "You will be given a math puzzle with a set of NUMBERS that can be combined with ('+', '-', '*', '/') to form some target RESULT."
#     "Think step by step and find the equation that satisfies the problem. Enclose your final formula in <answer></answer> tags."
#     "Example: [1, 2, 4] with a target of 2 -> <answer>1 * 4 - 2</answer>\n\nNUMBERS: {nums}\nRESULT: {result}\n"
# )

# def create_countdown_prompt_column(example):
#   example["prompt"] = countdown_prompt_template.format(nums=example['nums'], result=example['target'])
#   return example

In [None]:
# Prompt template with `{nums}` and `{target}` placeholders (filled via str.format).
# The template ends with an opening "<think>" tag.
COUNTDOWN_PROMPT = (
    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\n"
    "User: Using the numbers {nums}, create an equation that equals {target}. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in <think> </think> tags. And return the final answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>.\n"
    "Assistant: Let me solve this step by step.\n<think>"
)


def reformat_o4m_synth_row(example):
  """Convert one synthetic row into prompt/completion form.

  Everything in `cot` before the first "<answer>" is wrapped in
  <think></think>; the rest of `cot` is appended unchanged. The prompt is
  built from `numbers` and `target`, after which the source columns
  (`numbers`, `target`, `cot`, `expression`) are removed from the row.

  Raises:
    ValueError: if `cot` contains no "<answer>" tag.
  """
  cot = example['cot']
  split_at = cot.index("<answer>")
  reasoning, final_answer = cot[:split_at], cot[split_at:]
  # NOTE(review): COUNTDOWN_PROMPT already ends with "<think>" and this
  # completion starts with another one - confirm the doubled tag is intended.
  example["completion"] = f"<think>{reasoning}</think>{final_answer}"
  example["prompt"] = COUNTDOWN_PROMPT.format(nums=example['numbers'], target=example['target'])
  for column in ('numbers', 'target', 'cot', 'expression'):
    del example[column]
  return example


def prompt_to_entry(text):
  """Extract the (numbers, target) key of a Countdown prompt.

  Args:
    text: Prompt containing "Using the numbers [a, b, ...], create an
      equation that equals T". When the phrase occurs more than once, the
      last occurrence is used.

  Returns:
    A `(nums, target)` tuple: `nums` is the ascending tuple of the listed
    integers, `target` is an int. Sorting makes the key independent of the
    order in which the numbers were listed.

  Raises:
    ValueError: if the phrase is missing or a list item is not an integer.
  """
  # `[\s\S]*` (not `+`) so a prompt that begins with the phrase also matches.
  m = re.match(r"[\s\S]*Using the numbers \[(.*)\], create an equation that equals (\d+)", text)
  if m is None:
    raise ValueError(f"Prompt is not in the Countdown format: {text[:80]!r}")
  nums, target = m.group(1), m.group(2)
  # Split on bare commas and let int() absorb surrounding whitespace, so
  # "1,2" and "1,  2" parse the same as "1, 2".
  entry = (tuple(sorted(int(tok) for tok in nums.split(',') if tok.strip())), int(target))
  return entry

def reformat_rft(example):
  """Add a `prompt` field rendered from the row's `nums` and `target`, and return the row."""
  rendered = COUNTDOWN_PROMPT.format(nums=example['nums'], target=example['target'])
  example["prompt"] = rendered
  return example

def reformat_val(example):
  """Append a newline plus an opening "<think>" tag to the row's prompt and return the row."""
  example["prompt"] = example["prompt"] + "\n<think>"
  return example

In [None]:
# Selects which branch below builds sft_train_dataset / rft_train_dataset / valid_dataset.
DATASET_TYPE = "COUNTDOWN"

tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)
tokenizer.padding_side = 'left'


if DATASET_TYPE == "MATH":
  # We may not use all 50k, but if we increase/decrease the dataset during exp.,
  # give some wiggle room so we can be consistent with the validation set.
  N_TOTAL = 50_000
  N_SFT_TRAIN_EXAMPLES = 10_000
  N_RFT_TRAIN_EXAMPLES = 500
  assert(N_SFT_TRAIN_EXAMPLES + N_RFT_TRAIN_EXAMPLES < N_TOTAL)
  N_VALID_EXAMPLES = 1_000

  # https://huggingface.co/datasets/PrimeIntellect/verifiable-math-problems/viewer/default/train?views%5B%5D=train&row=4
  # Stream only the first N_TOTAL + N_VALID_EXAMPLES rows; validation is taken
  # from the tail of that slice so it stays fixed when the train sizes change.
  dataset_stream = load_dataset("PrimeIntellect/verifiable-math-problems", split="train", streaming=True)
  dataset_slice = list(islice(dataset_stream, N_TOTAL + N_VALID_EXAMPLES))
  sft_train_dataset = Dataset.from_list(dataset_slice[:N_SFT_TRAIN_EXAMPLES])
  rft_train_dataset = Dataset.from_list(dataset_slice[N_SFT_TRAIN_EXAMPLES:N_SFT_TRAIN_EXAMPLES + N_RFT_TRAIN_EXAMPLES])
  valid_dataset = Dataset.from_list(dataset_slice[N_TOTAL:N_TOTAL + N_VALID_EXAMPLES])

  # Mean prompt + solution token count per SFT example (truncated to int).
  avg_toks_per_example = 0
  for example in sft_train_dataset:
    avg_toks_per_example += len(tokenizer(example['prompt'], add_special_tokens=True)['input_ids']) + len(tokenizer(example['gold_standard_solution'], add_special_tokens=True)['input_ids'])
  avg_toks_per_example = int(avg_toks_per_example / N_SFT_TRAIN_EXAMPLES)

  # Keep only the prompt/completion columns for SFT.
  sft_train_dataset = sft_train_dataset.rename_column("gold_standard_solution", "completion")
  sft_train_dataset = sft_train_dataset.remove_columns([col for col in sft_train_dataset.column_names if col not in {"prompt", "completion"}])

elif DATASET_TYPE == "COUNTDOWN":
  # NOTE: N_RFT_TRAIN_EXAMPLES and avg_toks_per_example are only defined in the MATH branch above.
  # === COUNTDOWN V1 ===
  # sft_train_dataset = load_dataset("json", data_files="/content/drive/MyDrive/cs224r/data/synth_countdown_sft_10k.jsonl", split="train")\
  #     .rename_columns({"numbers": "nums", "cot": "completion"})\
  #     .map(create_countdown_prompt_column)\
  #     .remove_columns(["target", "nums"])
  # rft_train_dataset = load_dataset("predibase/countdown", split=f"train[:{N_RFT_TRAIN_EXAMPLES}]")\
  #   .map(create_countdown_prompt_column)
  # valid_dataset = load_dataset("predibase/countdown", split=f"train[{N_RFT_TRAIN_EXAMPLES}:]")\
  #   .map(create_countdown_prompt_column)
  # test_dataset = load_dataset("predibase/countdown", split="test")\
  #   .map(create_countdown_prompt_column)

  # Synthetic rows from Drive, reformatted to prompt/completion; rows 0-9699
  # go to training and rows 9700-9999 (300 rows) to validation.
  synth_dataset = load_dataset(
      "json",
      data_files="/content/drive/MyDrive/cs224r/data/synth_countdown_sft_10k.jsonl",
      split="train"
  ).map(reformat_o4m_synth_row)
  synth_train_dataset = synth_dataset.select(range(9700))
  synth_valid_dataset = synth_dataset.select(range(9700, 10000))  # val set 200 -> 500

  # SFT data = synthetic train rows + the cog_behav train split (its `query` column renamed to `prompt`).
  predet_train_dataset = load_dataset("Asap7772/cog_behav_all_strategies", split="train")\
        .rename_columns({"query": "prompt"})
  sft_train_dataset = concatenate_datasets([synth_train_dataset, predet_train_dataset])

  # RL data: first 20,000 rows, with a `prompt` column added by reformat_rft.
  rft_train_dataset = load_dataset("Jiayi-Pan/Countdown-Tasks-3to4", split="train[:20000]")\
        .map(reformat_rft)

  # Validation = cog_behav test split (prompts extended by reformat_val) + synthetic validation rows.
  valid_dataset = concatenate_datasets([
      load_dataset("Asap7772/cog_behav_all_strategies", split="test")\
        .rename_columns({"query": "prompt"})\
        .map(reformat_val),
      synth_valid_dataset
  ])

  # (sorted numbers, target) keys of every validation prompt.
  TARGET_BANK = set()
  for row in valid_dataset:
    TARGET_BANK.add(prompt_to_entry(row["prompt"]))

  def remove_data_leaks(row):
    """Keep a row only if its (numbers, target) key is absent from the validation set."""
    entry = prompt_to_entry(row['prompt'])
    return entry not in TARGET_BANK

  # NOTE(review): only the SFT set is filtered; rft_train_dataset is not checked against TARGET_BANK - confirm that is intended.
  sft_train_dataset = sft_train_dataset.filter(remove_data_leaks)

In [None]:
print(sft_train_dataset)
print(rft_train_dataset)
print(valid_dataset)

In [None]:
# def round_up_to_power_of_2(n):
#     # -- ChatGPT --
#     if n < 1:
#         return 1
#     # If n is already a power of two, return n
#     if (n & (n - 1)) == 0:
#         return n
#     power = 1
#     while power < n:
#         power <<= 1  # multiply power by 2
#     return power

# sequence_length = min(round_up_to_power_of_2(avg_toks_per_example), 2048)

In [None]:
# sequence_length

In [None]:
# model = AutoModelForCausalLM.from_pretrained(
#     PRETRAINED_MODEL_NAME,
#     # quantization_config=bnb_config,
#     torch_dtype="auto",
#     device_map="auto",
#     attn_implementation="flash_attention_2"
# )

# sft_training_args = SFTConfig(
#   # Memory Parameters
#   max_length=1024,
#   bf16=True,
#   gradient_checkpointing=True,
#   num_train_epochs=10,
#   save_strategy="epoch",
#   # Throughput
#   packing=True,
#   # Logging
#   logging_steps=20,
#   output_dir="./sft_checkpoints",
#   report_to="none",
# )

# sft_trainer = SFTTrainer(
#     model,
#     train_dataset=sft_train_dataset,
#     args=sft_training_args
# )
# sft_trainer.train()

In [None]:
# import os
# for ckpt in os.listdir("/content/sft_checkpoints/"):
#   try:
#     step = int(ckpt.split('-')[-1])
#     dst = f"/content/drive/MyDrive/cs224r/sft/countdown-joint/qwen2-0.5b-checkpoint-{step}"
#     if not os.path.exists(dst):
#       shutil.copytree(f"/content/sft_checkpoints/checkpoint-{step}", dst)
#   except:
#     continue

In [None]:
# # Automatically disconnect and shut down the Colab runtime
# import os
# from google.colab import runtime

# # Option 1: Using the Colab `runtime` module (cleanest method)
# runtime.unassign()

# # Option 2: Force shutdown (in case unassign doesn't work)
# os._exit(0)

In [None]:
# import os
# max_step = -1
# max_ckpt = None
# for ckpt in os.listdir("/content/sft_checkpoints/"):
#   try:
#     step = int(ckpt.split('-')[-1])
#     if step > max_step:
#       max_step = step
#       max_ckpt = ckpt
#   except:
#     continue

# if max_ckpt is not None:
#   dst = f"/content/drive/MyDrive/cs224r/sft/countdown/qwen2-0.5b-checkpoint-{max_step}"
#   if not os.path.exists(dst):
#     shutil.copytree(f"/content/sft_checkpoints/{max_ckpt}", dst)

In [None]:
# # Automatically disconnect and shut down the Colab runtime
# import os
# from google.colab import runtime

# # Option 1: Using the Colab `runtime` module (cleanest method)
# runtime.unassign()

# # Option 2: Force shutdown (in case unassign doesn't work)
# os._exit(0)

In [None]:
import os
import re
import time

from sympy import sympify, SympifyError

os.environ["PROFILE_REWARDS"] = "false"

def reward_accuracy(prompts, completions, verification_info, **kwargs):
  """
  Reward via string-equals-check on the boxed answer vs the ground truth.

  Args:
    prompts: Unused; accepted for the reward-function signature.
    completions: Completion strings, each searched for a boxed answer.
    verification_info: Strings containing "ground_truth': '<answer>'}".
    **kwargs: Ignored.

  Returns:
    One float per completion: 1.0 when both answers are found and equal after
    stripping whitespace, 0.0 otherwise.

  The boxed-answer capture is greedy: it runs from the first boxed opening
  brace to the last closing curly bracket on the same line, so several boxed
  answers on one line, or later curly brackets, end up inside the candidate.
  """
  start_time = time.time()
  rewards = []
  # Separate loop names so the `verification_info` argument is not rebound mid-iteration.
  for completion, info in zip(completions, verification_info):
    completion_match = re.search(r'\\boxed\{(.*)\}', completion)
    reference_match = re.match(r".*ground_truth':[\s]*'(.*)'}", info)
    if completion_match is None or reference_match is None:
      # Float, like the matched case below, so the list has a single element type.
      rewards.append(0.0)
    else:
      candidate, reference = completion_match.group(1).strip(), reference_match.group(1).strip()
      rewards.append(float(candidate == reference))
  if os.getenv("PROFILE_REWARDS", '').lower() == "true":
    print(f"Accuracy duration: {time.time() - start_time}")
  return rewards


def len_penalty(completions, **kwargs):
  """Length penalty: each completion scores -len(completion) / 1500.

  Longer completions therefore receive more negative rewards. `len` is applied
  to each completion as given, i.e. a character count when completions are
  strings (not a token count).
  """
  # Documentation bug:
  # TRL passes in `completion_ids` - https://github.com/huggingface/trl/blob/29401e790efd232a1bb14a247a7875fc789f0a7b/trl/trainer/grpo_trainer.py#L1167
  # But GRPO docs suggest `completions_ids` - https://huggingface.co/docs/trl/main/en/grpo_trainer#using-a-custom-reward-function
  started_at = time.time()
  # prev: 100000
  rewards = []
  for completion in completions:
    rewards.append(-float(len(completion)) / 1500)
  profiling = os.getenv("PROFILE_REWARDS", '').lower() == "true"
  if profiling:
    print(f"Length penalty duration: {time.time() - started_at}")
  return rewards


def format_reward(completions, **kwargs):
  """Score 1.0 for completions with well-formed think/answer markup, else 0.0.

  "<think>" is prepended to every completion before checking. A completion
  passes when <think>, </think>, <answer> and </answer> then each occur
  exactly once and appear in that order.
  """
  required_tags = ("<think>", "</think>", "<answer>", "</answer>")
  layout = r"[\s\S]*<think>[\s\S]*</think>[\s\S]*<answer>[\s\S]*</answer>[\s\S]*"

  def score(completion):
    text = "<think>" + completion
    if any(text.count(tag) != 1 for tag in required_tags):
      return 0.0
    return 1.0 if re.match(layout, text) is not None else 0.0

  return [score(completion) for completion in completions]


def formula_match(prompts, completions, target, nums, **kwargs):
  """Reward 1 when the <answer> expression uses exactly `nums` and equals `target`.

  Args:
    prompts: Unused apart from bounding the number of rows processed.
    completions: Completion strings; the expression is taken from the last
      <answer>...</answer> pair (it must sit on one line).
    target: Expected value for each completion.
    nums: Allowed numbers for each completion.
    **kwargs: Ignored.

  Returns:
    A list of ints (1 or 0), one per row.
  """
  rewards = []
  for _prompt, completion, goal, num_list in zip(prompts, completions, target, nums):
    answer = re.match(r"[\s\S]*<answer>(.*?)</answer>.*", completion)
    if answer is None:
      rewards.append(0)
      continue
    expression = answer.group(1)
    # Only numbers inside <answer> count: scanning the whole completion made
    # any digit in the reasoning text zero the reward. Comparing sorted lists
    # (not sets) also rejects answers that reuse a number.
    candidate_nums = sorted(int(i) for i in re.findall(r'\d+', expression))
    if candidate_nums != sorted(num_list):
      rewards.append(0)
      continue
    try:
      # NOTE(review): sympify evaluates model-generated text; see formula_match_v3 for the variant used in training.
      rewards.append(int(sympify(expression) == goal))
    except (SympifyError, TypeError):
      rewards.append(0)
  return rewards


def formula_match_v2(prompts, completions, **kwargs):
  """Reward 1 when the <answer> expression evaluates to the prompt's target, else 0.

  The target is parsed from the prompt text. The numbers used in the
  expression are not checked in this version.
  """
  rewards = []
  for prompt, completion in zip(prompts, completions):
    target_match = re.match(r"[\s\S]*create an equation that equals (\d+).*", prompt)
    target = int(target_match.group(1))
    # Parsed for parity with the other versions; the numbers are not used here.
    _ref_nums, _ = prompt_to_entry(prompt)
    answer = re.match(r"[\s\S]*<answer>(.*?)</answer>.*", completion)

    try:
      # A missing answer surfaces as AttributeError on `.group` and scores 0.
      # NOTE(review): sympify evaluates model-generated text - confirm this is acceptable.
      is_correct = sympify(answer.group(1)) == target
      rewards.append(int(is_correct))
    except (SympifyError, TypeError, AttributeError):
      rewards.append(0)
  return rewards


def formula_match_v3(prompts, completions, **kwargs):
  """Reward 1.0 when the <answer> expression hits the target with allowed numbers only.

  For each prompt/completion pair the target and the reference numbers are
  parsed from the prompt, and the expression is taken from the last
  <answer>...</answer> pair of the completion (it must sit on one line).

  A completion scores 0.0 when:
    * it has no answer tag pair,
    * the expression contains anything other than digits, whitespace,
      "+", "-", "*", "/", "(", ")" and ".", or contains "**",
    * any number occurs more often than in the prompt (this also covers
      numbers that are not in the prompt at all),
    * the expression cannot be evaluated or does not equal the target.

  Using fewer numbers than the prompt lists is allowed.

  Returns:
    A list of floats (1.0 or 0.0), one per pair.
  """
  rewards = []
  for prompt, completion in zip(prompts, completions):
    target = int(
        re.match(r"[\s\S]*create an equation that equals (\d+).*", prompt).group(1)
    )
    ref_nums, _ = prompt_to_entry(prompt)
    answer = re.match(r"[\s\S]*<answer>(.*?)</answer>.*", completion)
    if answer is None:
      rewards.append(0.0)
      continue
    expression = answer.group(1)

    # sympify() evaluates its input, and this text comes from the model.
    # Only plain arithmetic is let through: that keeps arbitrary names/calls
    # out of the evaluator and blocks "**" towers that could stall training.
    # The task statement only permits +, -, * and / anyway.
    if "**" in expression or re.fullmatch(r"[\d\s+\-*/().]+", expression) is None:
      rewards.append(0.0)
      continue

    # Did we use duplicate numbers when we shouldn't have,
    # or hallucinate numbers to make the numbers work?
    cand_nums = [int(i) for i in re.findall(r"\d+", expression)]
    ref_counts = {i: ref_nums.count(i) for i in ref_nums}
    cand_counts = {i: cand_nums.count(i) for i in cand_nums}
    if any(cand_counts[i] > ref_counts.get(i, 0) for i in cand_counts):
      rewards.append(0.0)
      continue

    try:
      candidate_target = sympify(expression)
      rewards.append(float(candidate_target == target))
    except (SympifyError, TypeError, AttributeError):
      rewards.append(0.0)
  return rewards

In [None]:
valid_dataset[0]

In [None]:
# Math
# Smoke test for the MATH rewards. The first boxed answer equals its ground
# truth and the second does not, so reward_accuracy should give 1 then 0.
# len_penalty should print -13/1500 and -11/1500 (completion lengths 13 and 11).
# `completion_ids` is accepted through **kwargs and not used by either function.
prompts = ["Do some math, eh?", "What is math?"]
completions = ["\\boxed{x = 3}", "\\boxed{712}"]
reference_answers = ["{'ground_truth': 'x = 3'}", "{'ground_truth': '711'}"]
completion_ids = [[6303, 13], [304, 279, 12884, 13]]
print(reward_accuracy(prompts=prompts, completions=completions, completion_ids=completion_ids, verification_info=reference_answers))
print(len_penalty(prompts=prompts, completions=completions, completion_ids=completion_ids, verification_info=reference_answers))

In [None]:
# Countdown
# Smoke test for the Countdown rewards. Only the first completion has a closing
# </think> plus one answer pair (format reward 1.0), and only its expression
# uses the listed numbers and equals its target (12 + 40 / 10 = 16).
completions = ["Let me think about that ... </think><answer>12 + 40 / 10</answer> Was that okay?",
               "Some stuff <answer>3 + 4 + 5</answer> How'd I do?",
               "This one on the other hand has me stumped"]
targets = [16, 13, 4]
nums = [[12, 40, 10], [3, 4, 5], [20, 50, 63]]
print(format_reward(completions))
print(formula_match(["p1", "p2", "p3"], completions, targets, nums))

In [None]:
comps = [
    "Human-readable solution:\n1. Since 57 is our target and we have two 76s, we need to find a way to get one of them into the same form as 57\n2. Looking at 76, if we divide 76 by something, we might get close to 57\n3. Let's try dividing 76 by 57:\n   - 76/57 ≈ 1.29\n4. Another approach: \n   - 76-57 = 19\n   - 76*57 would be too large\n5. Let's try another sequence:\n   - 76+57 = 133\n   - 133/76 ≈ 1.80\n6. New approach:\n   - 76-76 = 0\n   - 57+0 = 57\n7. Found it! We can do:\n   - First subtract 76 from 76: 76-76 = 0\n   - Then add that result to 57: 0+57 = 57\n</think>\n<answer>(76-76)+57</answer>",
    "<think>some chain of thought</think><answer>1 + 2 + 3</answer>",   # wrong answer, wrong nums
    "<think>some chain of thought</think><answer>28 * 2 + 1</answer>",  # right answer, wrong nums
    "<think>some chain of thought</think><answer>57 / 76 + 76</answer>", # wrong answer, right nums
    "<think>some chain of thought</think><answer>57 + 76 - 76 + 76 - 76</answer>", # right answer, right nums but duplicates
    ]
prompts = [
    "'A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [57, 76, 76], create an equation that equals 57. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in <think> </think> tags. And return the final answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>.\nAssistant: Let me solve this step by step.",
    "'A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [57, 76, 76], create an equation that equals 57. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in <think> </think> tags. And return the final answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>.\nAssistant: Let me solve this step by step.",
    "'A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [57, 76, 76], create an equation that equals 57. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in <think> </think> tags. And return the final answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>.\nAssistant: Let me solve this step by step.",
    "'A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [57, 76, 76], create an equation that equals 57. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in <think> </think> tags. And return the final answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>.\nAssistant: Let me solve this step by step.",
    "'A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [57, 76, 76], create an equation that equals 57. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in <think> </think> tags. And return the final answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>.\nAssistant: Let me solve this step by step.",
]
# Expected: [1.0, 0.0, 0.0, 0.0, 0.0] - only the first completion equals 57 while staying within the counts of [57, 76, 76].
print(formula_match_v3(prompts, comps))

In [None]:
# MATH only: mean prompt / gold-solution token counts over the RL training set
# (N_RFT_TRAIN_EXAMPLES is defined only in the MATH branch of the dataset cell).
if DATASET_TYPE == "MATH":
  avg_prompt_toks, avg_completion_toks = 0, 0
  for example in rft_train_dataset:
    avg_prompt_toks += len(tokenizer(example['prompt'], add_special_tokens=True)['input_ids'])
    avg_completion_toks += len(tokenizer(example['gold_standard_solution'], add_special_tokens=True)['input_ids'])
  avg_prompt_toks = int(avg_prompt_toks / N_RFT_TRAIN_EXAMPLES)
  avg_completion_toks = int(avg_completion_toks / N_RFT_TRAIN_EXAMPLES)
  print(f"{avg_prompt_toks=}")
  print(f"{avg_completion_toks=}")

In [None]:
import wandb
wandb.init(mode="disabled")  # since for some reason report_to='none' is not enough for GRPO

In [None]:
import os
import shutil

if not os.path.exists("/content/local-sft-checkpoint"):
  shutil.copytree("/content/drive/MyDrive/cs224r/sft/countdown-joint/qwen2-0.5b-checkpoint-11050", "/content/local-sft-checkpoint")
    # shutil.copytree("/content/drive/MyDrive/cs224r/sft/countdown-joint/qwen2-0.5b-checkpoint-1105", "/content/local-sft-checkpoint")

In [None]:
#### <-- Moved this into custom trl fork; no longer necessary in colab -->

# ### FIX ANNOYING TRL <-> TRANSFORMERS BUGS ###
# # Fix bug in trl <-> transformers
# from transformers.trainer_utils import seed_worker
# from torch.utils.data import DataLoader
# from transformers.utils import is_datasets_available
# from transformers import TrainerCallback

# import json


# class LogCallback(TrainerCallback):
#     def __init__(self):
#       self.fpath = "/content/logs/latest_metrics.txt"
#       os.makedirs(os.path.dirname(self.fpath), exist_ok=True)

#     def on_log(self, args, state, control, logs=None, **kwargs):
#       with open(self.fpath, "a") as f:
#         if logs:
#           f.write(json.dumps(logs) + '\n')


# class HappyGRPOTrainer(GRPOTrainer):
#     def get_train_dataloader(self):
#         if self.train_dataset is None:
#             raise ValueError("Trainer: training requires a train_dataset.")

#         train_dataset = self.train_dataset
#         data_collator = self.data_collator
#         if is_datasets_available() and isinstance(train_dataset, Dataset):
#             train_dataset = self._remove_unused_columns(train_dataset, description="training")
#         else:
#             data_collator = self._get_collator_with_removed_columns(data_collator, description="training")

#         dataloader_params = {
#             "batch_size": self._train_batch_size * self.args.gradient_accumulation_steps,  # < this is the change
#             "collate_fn": data_collator,
#             "num_workers": self.args.dataloader_num_workers,
#             "pin_memory": self.args.dataloader_pin_memory,
#             "persistent_workers": self.args.dataloader_persistent_workers,
#         }

#         if not isinstance(train_dataset, torch.utils.data.IterableDataset):
#             dataloader_params["sampler"] = self._get_train_sampler()
#             dataloader_params["drop_last"] = self.args.dataloader_drop_last
#             dataloader_params["worker_init_fn"] = \
#               partial(seed_worker,
#                       num_workers=rft_training_args.dataloader_num_workers,
#                       rank=rft_training_args.process_index)
#             dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

#         return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))

# metric_file_logger = LogCallback()

In [None]:
"""
Qwen2.5 0.5B params:

rft_training_args = GRPOConfig(
    # Memory
    bf16=True,
    per_device_train_batch_size=24,
    per_device_eval_batch_size=0,
    num_generations=4,
    gradient_checkpointing=True,
    max_prompt_length=256,
    max_completion_length=512,
    gradient_accumulation_steps=4,
    # Throughput
    use_vllm=False,
    # vllm_mode="colocate",  # single-gpu setting
    num_train_epochs = 1, # Set to 1 for a full training run
    max_steps = -1,
    dataloader_num_workers=8,
    dataloader_prefetch_factor=2,
    dataloader_pin_memory=True,
    dataloader_persistent_workers=True,
    # Logging
    output_dir="./rft-checkpoints",
    # save_strategy="epoch",
    # save_total_limit=2,
    report_to="none",
)
"""

In [None]:
%load_ext tensorboard

In [None]:
%tensorboard --logdir ./logs

In [None]:
import json

with open("/content/drive/MyDrive/cs224r/rft/countdown-joint/batches_to_exclude_bs16.json", "r") as f:
  batches_to_exclude = {i for i, _ in json.load(f)}

In [None]:
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [None]:
# {'top_k': 20, 'top_p': 0.8, 'repetition_penalty': 1.1, 'bos_token_id': 151643}

# Names this run; used for the TensorBoard log directory and run_name below.
RUN_NAME = "V31"

# LoRA adapter settings handed to the trainer via `peft_config`.
peft_config = LoraConfig(
  r=8,
  lora_alpha=32,
  lora_dropout=0.1,
  bias="none",
  task_type="CAUSAL_LM",
)

rft_training_args = GRPOConfig(
    # Memory
    bf16=True,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=0,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},  # it sets use_cache=False anyways, but to be explicit...
    gradient_accumulation_steps=1,
    # Throughput
    use_vllm=False,
    # vllm_mode="colocate",  # single-gpu setting
    dataloader_num_workers=8,
    dataloader_prefetch_factor=2,
    dataloader_pin_memory=True,
    dataloader_persistent_workers=True,
    # Logging
    output_dir="./rft-checkpoints",
    logging_dir=f"./logs/{RUN_NAME}",
    save_strategy="steps",
    save_steps=25,
    run_name=RUN_NAME,
    # save_total_limit=2,
    logging_steps=1,
    report_to=["tensorboard"],
    # Training
    # One weight per entry of `reward_funcs` below, in the same order:
    # formula_match_v3 -> 1.0, format_reward -> 0.1.
    reward_weights=[1.0, 0.1],
    # scale_rewards=False,
    # loss_type="dr_grpo",
    beta=0.001,
    num_train_epochs = 1, # Set to 1 for a full training run
    learning_rate=1e-5,
    # lr_scheduler_type="cosine",
    # warmup_ratio=0.03,
    # NOTE(review): a positive max_steps presumably takes precedence over num_train_epochs - confirm against the TrainingArguments docs.
    max_steps = 400,
    # max_grad_norm=1.5,
    # epsilon_high=0.28,
    # Sampling
    num_generations=16,
    max_prompt_length=256,
    max_completion_length=1024,
    mask_truncated_completions=True,
    temperature=1.01,
    top_p=0.8,
    top_k=100,
    # NOTE(review): `bottom_p` presumably comes from the forked trl installed at the top of the notebook - confirm.
    bottom_p=0.9,
)

# Starts from the local copy of the SFT checkpoint and trains LoRA adapters with GRPO.
rft_trainer = GRPOTrainer(
    "/content/local-sft-checkpoint",
    # reward_funcs=[reward_accuracy, len_penalty],  # MATH
    reward_funcs=[formula_match_v3, format_reward],
    args=rft_training_args,
    train_dataset=rft_train_dataset,
    # callbacks=[metric_file_logger],
    peft_config=peft_config,
    optimizers=(None, None),
    # batches_to_exclude=batches_to_exclude,
    # optimizer_cls_and_kwargs=(bitsandbytes.optim.Lion8bit, {}),
)

In [None]:
# rft_trainer.model = torch.compile(rft_trainer.model, mode="default", fullgraph=True)

In [None]:
rft_trainer.train()

In [None]:
rft_trainer.generation_config.temperature

In [None]:
merged_model = rft_trainer.model.merge_and_unload()

In [None]:
merged_model.save_pretrained("/content/rft-merged-ckpt")

In [None]:
print(rft_trainer.state.log_history[1])

In [None]:
# with open("/content/rft-checkpoints/checkpoint-200/trainer_state.json", "r") as f:
#   metrics_collection = [step for step in json.load(f)['log_history'] if "rewards/formula_match_v3/mean" in step]

In [None]:
import matplotlib.pyplot as plt

# Pull per-step series out of the trainer's log history for the plots below.
# If `metrics_collection` is pre-populated (e.g. from a saved trainer_state.json)
# it is used as-is; otherwise only log entries carrying the reward key are kept.
metrics_collection = None
# NOTE(review): the GRPOConfig in this notebook sets num_generations=16, while the
# error bars (std / sqrt(num_gen)) use 4 - confirm which value applies.
num_gen = 4  # rft_training_args.num_generations

if DATASET_TYPE == "MATH":
  metrics_collection = metrics_collection or [step for step in rft_trainer.state.log_history if "rewards/reward_accuracy/mean" in step]
  steps = [step['step'] for step in metrics_collection]
  accs = [step['rewards/reward_accuracy/mean'] for step in metrics_collection]
  accs_errs = [step['rewards/reward_accuracy/std'] / (num_gen ** 0.5) for step in metrics_collection]
  lens = [step['rewards/len_penalty/mean'] for step in metrics_collection]
  lens_errs = [step['rewards/len_penalty/std'] / (num_gen ** 0.5) for step in metrics_collection]
  rewards = [step['reward'] for step in metrics_collection]
elif DATASET_TYPE == "COUNTDOWN":
  metrics_collection = metrics_collection or [step for step in rft_trainer.state.log_history if "rewards/formula_match_v3/mean" in step]
  steps = [step['step'] for step in metrics_collection]
  accs = [step['rewards/formula_match_v3/mean'] for step in metrics_collection]
  accs_errs = [step['rewards/formula_match_v3/std'] / (num_gen ** 0.5) for step in metrics_collection]
  formats = [step['rewards/format_reward/mean'] for step in metrics_collection]
  formats_errs = [step['rewards/format_reward/std'] / (num_gen ** 0.5) for step in metrics_collection]
  # `lens` / `lens_errs` are not defined in this branch (len_penalty is not among the Countdown reward functions).
  # lens = [step['rewards/len_penalty/mean'] for step in metrics_collection]
  # lens_errs = [step['rewards/len_penalty/std'] / (num_gen ** 0.5) for step in metrics_collection]
  rewards = [step['reward'] for step in metrics_collection]
  kls = [step['kl'] for step in metrics_collection]
  grad_norms = [step['grad_norm'] for step in metrics_collection]


In [None]:
import numpy as np

# Mean accuracy reward per logged step, with a band of +/- one standard error.
accs_mid, accs_half_width = np.array(accs), np.array(accs_errs)
plt.plot(steps, accs)
plt.fill_between(steps, accs_mid - accs_half_width, accs_mid + accs_half_width, color='blue', alpha=0.3)
plt.xlabel("Step")
# The sample size in the label follows `num_gen` instead of a hard-coded 4.
plt.ylabel(f"[Countdown v3] Mean Accuracy Reward (n={num_gen})")
plt.title("Accuracy over time for GRPO Qwen2 0.5B")

In [None]:
# Mean format reward per logged step, with a band of +/- one standard error.
formats_mid, formats_half_width = np.array(formats), np.array(formats_errs)
plt.plot(steps, formats)
plt.fill_between(steps, formats_mid - formats_half_width, formats_mid + formats_half_width, color='blue', alpha=0.3)
plt.xlabel("Step")
# The y-axis previously said "Mean Length Penalty", but this plot shows the format reward.
plt.ylabel(f"Mean Format Reward (n={num_gen})")
plt.title("Format reward over time for GRPO Qwen2 0.5B")

In [None]:
# KL value logged at each step (`kls` is only populated for the COUNTDOWN dataset).
plt.plot(steps, kls)
plt.xlabel("Step")
# The sample size in the label follows `num_gen` instead of a hard-coded 4.
plt.ylabel(f"[Countdown v3] KL (n={num_gen})")
plt.title("KL || GRPO Qwen2 0.5B")

In [None]:
# Gradient norm logged at each step (`grad_norms` is only populated for the COUNTDOWN dataset).
plt.plot(steps, grad_norms)
plt.xlabel("Step")
plt.ylabel("[Countdown v3] Gradient Norms")
plt.title("Gradient Norms over Time || GRPO Qwen2 0.5B")

In [None]:
# `lens` / `lens_errs` are only extracted for the MATH dataset; without this
# guard the cell raises NameError when DATASET_TYPE is "COUNTDOWN".
if DATASET_TYPE == "MATH":
  lens_mid, lens_half_width = np.array(lens), np.array(lens_errs)
  plt.plot(steps, lens)
  plt.fill_between(steps, lens_mid - lens_half_width, lens_mid + lens_half_width, color='blue', alpha=0.3)
  plt.xlabel("Step")
  plt.ylabel(f"Mean Length Penalty (n={num_gen})")
  plt.title("Length penalty over time for GRPO Qwen2.5 1.5B")

In [None]:
from accelerate import Accelerator

accelerator = Accelerator()
unwrapped_model = accelerator.unwrap_model(rft_trainer.model)

In [None]:
unwrapped_model.save_pretrained("/content/rft-checkpoints/checkpoint-5-n=1000")

In [None]:
step = 750
version = "v9"
dst = f"/content/drive/MyDrive/cs224r/rft/countdown-joint/{version}/checkpoint-{step}"
if not os.path.exists(dst):
  shutil.copytree(f"/content/rft-checkpoints/checkpoint-{step}", dst)

In [None]:
# import os
# for ckpt in os.listdir("/content/rft-checkpoints/"):
#   try:
#     step = int(ckpt.split('-')[-1])
#     dst = f"/content/drive/MyDrive/cs224r/rft/countdown-o4m-distil/qwen2-0.5b-checkpoint-{step}"
#     if not os.path.exists(dst):
#       shutil.copytree(f"/content/rft-checkpoints/checkpoint-{step}", dst)
#   except:
#     continue

== WITHOUT PEFT ==

bs8 / grad_ckpt False => no OOM; 45 mins

bs16 / grad_ckpt True => no OOM; 30 mins

<pre>
rft_training_args = GRPOConfig(
    # Memory
    bf16=True,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=0,
    num_generations=4,
    gradient_checkpointing=True,
    max_prompt_length=256,
    max_completion_length=512,
    gradient_accumulation_steps=4,
    # Throughput
    use_vllm=False,
    # vllm_mode="colocate",  # single-gpu setting
    num_train_epochs = 1, # Set to 1 for a full training run
    max_steps = -1,
    dataloader_num_workers=8,
    dataloader_prefetch_factor=2,
    dataloader_pin_memory=True,
    dataloader_persistent_workers=True,
    # save_steps = 10000,
    # Logging
    output_dir="./rft-checkpoints",
    report_to="none",
)</pre>

setting per_device_train_batch_size=24 => no OOM; 25 mins

also setting gradient_accumulation_steps=2 => no OOM; 40 mins

In [None]:
# shutil.copytree raises FileExistsError when the destination already exists,
# so guard the copy like the other checkpoint-copy cells to keep re-runs safe.
if not os.path.exists("/content/drive/MyDrive/cs224r/rft/qwen2-1.5b-checkpoint-5_n=500"):
  shutil.copytree("/content/rft-checkpoints/checkpoint-5-n=1000",
                  "/content/drive/MyDrive/cs224r/rft/qwen2-1.5b-checkpoint-5_n=500")

In [None]:
import os
import shutil

# Pull the SFT checkpoints from Google Drive onto local disk, unless a local
# copy is already present.
sft_checkpoints_present = os.path.exists('/content/sft_checkpoints')
if not sft_checkpoints_present:
  shutil.copytree(
      "/content/drive/MyDrive/cs224r/sft/countdown-joint/",
      "/content/sft_checkpoints",
  )

# Eval

In [None]:
# Pull the v5 RFT checkpoint-200 from Google Drive onto local disk.
# `os` is imported here so the cell also runs on a fresh kernel.
import os

# NOTE(review): this restores into "/content/rft-checkpoints-v5/", whereas the
# `checkpoints` list in the next cell points at "/content/rft-checkpoints/" --
# confirm which directory the evaluation is meant to read.
rft_v5_src = "/content/drive/MyDrive/cs224r/rft/countdown-joint/v5/checkpoint-200"
rft_v5_dst = "/content/rft-checkpoints-v5/checkpoint-200"

# shutil.copytree raises FileExistsError when the destination already exists,
# so skip the copy if this cell has been run before.
if not os.path.exists(rft_v5_dst):
  shutil.copytree(rft_v5_src, rft_v5_dst)

In [None]:
import gc
import torch

from functools import partial
from transformers import AutoTokenizer, AutoModelForCausalLM, Qwen2ForCausalLM
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding, logging
from tqdm import tqdm

# Restrict transformers' logging output to error-level messages.
logging.set_verbosity_error()

# Earlier single-checkpoint choices, kept for reference.
# model_ckpt = "/content/local-sft-checkpoint"
# model_ckpt = "/content/rft-checkpoints/qwen2-0.5b-checkpoint-3093"
# model_ckpt = "/content/sft_checkpoints/checkpoint-250"
# model_ckpt = "/content/sft_checkpoints/qwen2-0.5b-checkpoint-11050"
# model_ckpt = "PRETRAINED_MODEL_NAME"

# Checkpoint directories scored, in this order, by the evaluation loop further
# down. Comment entries in or out to choose which checkpoints are evaluated.
checkpoints = [
    # "/content/local-sft-checkpoint",
    "/content/rft-checkpoints/checkpoint-25",
    "/content/rft-checkpoints/checkpoint-50",
    "/content/rft-checkpoints/checkpoint-100",
    "/content/rft-checkpoints/checkpoint-150",
    "/content/rft-checkpoints/checkpoint-200",
    # "/content/rft-checkpoints/checkpoint-250",
    # "/content/rft-checkpoints/checkpoint-300",
    # "/content/rft-checkpoints/checkpoint-350",
    # "/content/rft-checkpoints/checkpoint-400",
]

In [None]:
# Number of completions sampled per prompt. It is passed to `generate` through
# `sampling_kwargs` and also used by `get_completions` to group the outputs.
num_return_sequences = 5

# Keyword arguments forwarded to `model.generate`.
# For deterministic (argmax greedy) decoding, use do_sample=False instead.
sampling_kwargs = dict(
    do_sample=True,
    top_k=20,
    top_p=0.8,
    temperature=0.3,
    repetition_penalty=1.1,
    num_return_sequences=num_return_sequences,
    # bos_token_id=151643,
)

def tokenize(example, tokenizer):
  """Tokenize the "prompt" field of `example` with `tokenizer`.

  The tokenizer is called with left-side padding enabled, PyTorch tensor
  output requested, and no special tokens added.

  Args:
    example: mapping with a "prompt" entry (a batch of prompts when used with
      `Dataset.map(..., batched=True)`).
    tokenizer: callable tokenizer.

  Returns:
    Whatever `tokenizer(...)` returns for the prompts.
  """
  prompts = example["prompt"]
  encode_options = {
      "padding": True,
      "padding_side": "left",
      "return_tensors": 'pt',
      "add_special_tokens": False,
  }
  return tokenizer(prompts, **encode_options)

# === Inference loop ===
# === Inference loop ===
def get_completions(model, tokenizer, data, device='cuda', bs=64):
  """Sample `num_return_sequences` completions for every example in `data`.

  Reads the module-level `num_return_sequences` and `sampling_kwargs`.

  Args:
    model: model exposing `generate(...)`.
    tokenizer: supplies `eos_token_id` (used as the pad token during
      generation) and `decode`.
    data: dataset to iterate over. Each batch produced by the DataLoader must
      contain "input_ids" and "attention_mask" tensors.
    device: device the input tensors are moved to.
    bs: DataLoader batch size.

  Returns:
    A list of `num_return_sequences` lists. `completions[i][j]` is the decoded
    text of the i-th sampled completion for the j-th example in `data`.
  """
  completions = [[] for _ in range(num_return_sequences)]
  max_new_tokens = 1024
  # Iterate over the dataset that was passed in, not a notebook global.
  val_loader = DataLoader(data, batch_size=bs)

  with torch.no_grad():
    for batch in tqdm(val_loader, position=0, leave=True):
      input_ids = batch["input_ids"].to(device)
      attention_mask = batch["attention_mask"].to(device)

      with torch.autocast("cuda", dtype=torch.bfloat16):
        generated_ids = model.generate(
            input_ids=input_ids, attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            **sampling_kwargs,
            pad_token_id=tokenizer.eos_token_id
        )  # (batch_size x num_return_sequences, max_length)

      # Group the flat output by input row.
      generated_ids = generated_ids.reshape(
          input_ids.shape[0],  # true batch size; may be < bs if final
          num_return_sequences,
          generated_ids.shape[1],  # max tokens from this batch
      )

      # Every row of the batch tensor has the same length, so the prompt
      # prefix to strip is the same for all rows.
      prompt_len = input_ids.shape[1]

      # Only keep the generated suffix
      for idx in range(num_return_sequences):
        for gen in generated_ids[:, idx, :]:
          gen_text = tokenizer.decode(gen[prompt_len:], skip_special_tokens=True)
          completions[idx].append(gen_text)
  return completions

In [None]:
# Choose the split to score, then tokenize its prompts in batches and expose
# only the tensor columns needed for generation.
dataset_to_evaluate = valid_dataset
tokenize_prompts = partial(tokenize, tokenizer=tokenizer)
processed_valid_dataset = dataset_to_evaluate.map(
    tokenize_prompts,
    batched=True,
)
processed_valid_dataset.set_format(
    "torch",
    columns=["input_ids", "attention_mask"],
)

In [None]:
from peft import AutoPeftModelForCausalLM

# Evaluate every checkpoint in `checkpoints`.
# completion_collection[c][i][j] is the i-th sampled completion for example j
# from checkpoint c.
completion_collection = []
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = None

for model_ckpt in checkpoints:
  # Drop the reference to the previous checkpoint *before* loading the next
  # one; otherwise both models are alive during loading and gc.collect()
  # cannot reclaim anything.
  model = None
  gc.collect()
  if torch.cuda.is_available():
    torch.cuda.empty_cache()

  model = AutoModelForCausalLM.from_pretrained(model_ckpt,
                                               torch_dtype="auto",  # bfloat16
                                               device_map="auto")
  model.config.use_cache = True
  model = torch.compile(model.to(device))
  model.eval()

  completions_list = get_completions(model, tokenizer, processed_valid_dataset)

  # Score each of the sampled completion sets against the same dataset the
  # completions were generated from (`dataset_to_evaluate`).
  if DATASET_TYPE == "MATH":
    reward_floats = [
        reward_accuracy(prompts=None, completions=completions,
                        verification_info=dataset_to_evaluate['verification_info'])
        for completions in completions_list
    ]
  elif DATASET_TYPE == "COUNTDOWN":
    reward_floats = [
        formula_match_v3(prompts=dataset_to_evaluate["prompt"], completions=completions)
        for completions in completions_list
    ]
  else:
    # Unknown dataset type: completions are still collected, just not scored.
    reward_floats = None

  if reward_floats is not None:
    flat_rewards = [it for coll in reward_floats for it in coll]
    if flat_rewards:
      print(f"{model_ckpt} accuracy: {sum(flat_rewards) / len(flat_rewards) * 100}")
    else:
      # Avoid ZeroDivisionError when there is nothing to score.
      print(f"{model_ckpt} accuracy: n/a (no completions to score)")

  completion_collection.append(completions_list)

In [None]:
# Index layout: completion_collection[checkpoint][sampled return][example].
# Shows one completion: first checkpoint, first sampled return, example 1.
completion_collection[0][0][1]

In [None]:
# All completions from sampled return 1 of checkpoints[1] (one per example).
completion_collection[1][1]

In [None]:
# All completions from sampled return 1 of checkpoints[2] (one per example).
completion_collection[2][1]

In [None]:
# All completions from sampled return 1 of checkpoints[3] (one per example).
completion_collection[3][1]

In [None]:
# All completions from sampled return 1 of checkpoints[4] (one per example).
completion_collection[4][1]

In [None]:
import json

# Persist every checkpoint's sampled completions to Google Drive as a single
# JSON object under the "completions" key.
completions_payload = {"completions": completion_collection}
with open("/content/drive/MyDrive/cs224r/completions/rft-v31.json", "w") as out_file:
  json.dump(completions_payload, out_file)

In [None]:
# def tokenize(example):
#   return tokenizer(example["prompt"], padding="longest", truncation=True, return_tensors='pt')

# processed_valid_dataset = valid_dataset.map(tokenize, batched=True)
# processed_valid_dataset.set_format("torch", columns=["input_ids", "attention_mask"])

In [None]:
# # === Save completions ===
# if not os.path.exists("rft_completions.txt"):
#   with open("rft_completions.txt", "w", encoding="utf-8") as f:
#       for line in completions:
#           f.write(line.strip() + "\n")

In [None]:
# if DATASET_TYPE == "MATH":
#   rewards = reward_accuracy(prompts=None, completions=completions, verification_info=valid_dataset['verification_info'])
#   print(f"Accuracy: {sum(rewards) / len(rewards) * 100}")
# elif DATASET_TYPE == "COUNTDOWN":
#   rewards = formula_match_v3(prompts=valid_dataset["prompt"], completions=completions)
#   print(f"Accuracy: {sum(rewards) / len(rewards) * 100}")

In [None]:
# rewards_v2 = formula_match_v2(prompts=valid_dataset["prompt"], completions=completions)
# rewards_v3 = formula_match_v3(prompts=valid_dataset["prompt"], completions=completions)

# for p, c, r2, r3 in zip(valid_dataset["prompt"], completions, rewards_v2, rewards_v3):
#   if r2 != r3:
#     print(f"PROMPT: {p}\n\nCOMPLETION: {c}")

In [None]:
# Play a short beep in the browser via Colab's JS bridge -- presumably an
# audible cue that the preceding cells have finished running.
from google.colab import output
output.eval_js('new Audio("https://upload.wikimedia.org/wikipedia/commons/0/05/Beep-09.ogg").play()')

In [None]:
# Show completion N next to validation example N.
# NOTE(review): `completions` is not assigned at top level in the cells
# visible here (the evaluation loop produces `completions_list` and
# `completion_collection`) -- confirm it is defined before running this cell.
N = 1
completions[N], valid_dataset[N]

In [None]:
# Score only example 5 with `reward_accuracy`, passing a one-element list of
# completions and the matching one-element slice of verification info.
# NOTE(review): relies on a top-level `completions` that is not assigned in
# the cells visible here -- confirm it is defined.
reward_accuracy(prompts=None, completions=[completions[5]], verification_info=valid_dataset['verification_info'][5:6])

In [None]:
# Show the full validation record at index 5.
valid_dataset[5]

In [None]:
# Show the prompt and verification info of the validation record at index 9.
valid_dataset[9]['prompt'], valid_dataset[9]['verification_info']

In [None]:
# Build the LoRA config, the GRPO training arguments and the trainer for the
# RFT run. Requires `HappyGRPOTrainer`, `formula_match_v3`, `format_reward`,
# `rft_train_dataset` and `metric_file_logger` to be defined by earlier cells.

# Reference sampling settings:
# {'top_k': 20, 'top_p': 0.8, 'repetition_penalty': 1.1, 'bos_token_id': 151643}

# LoRA adapter settings handed to the trainer via `peft_config`.
peft_config = LoraConfig(
  r=8,
  lora_alpha=32,
  lora_dropout=0.1,
  bias="none",
  task_type="CAUSAL_LM",
)

rft_training_args = GRPOConfig(
    # Memory
    bf16=True,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=0,
    num_generations=4,
    learning_rate=5e-7,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},  # it sets use_cache=False anyways, but to be explicit...
    max_prompt_length=256,
    max_completion_length=1024,
    gradient_accumulation_steps=1,
    # Throughput
    use_vllm=False,
    # vllm_mode="colocate",  # single-gpu setting
    num_train_epochs = 1, # Set to 1 for a full training run
    max_steps = -1,
    dataloader_num_workers=8,
    dataloader_prefetch_factor=2,
    dataloader_pin_memory=True,
    dataloader_persistent_workers=True,
    # Logging
    output_dir="./rft-checkpoints",
    save_strategy="steps",
    save_steps=50,
    # save_total_limit=2,
    logging_steps=1,
    report_to="none",
    # Training
    # NOTE(review): presumably one weight per entry of `reward_funcs` below,
    # in the same order -- confirm against the trl fork in use.
    reward_weights=[1.0, 0.1],
    beta=0.001,
)

# NOTE(review): the first positional argument is presumably the model (here a
# local SFT checkpoint path) -- confirm against HappyGRPOTrainer's signature.
rft_trainer = HappyGRPOTrainer(
    "/content/local-sft-checkpoint",
    # reward_funcs=[reward_accuracy, len_penalty],  # MATH
    reward_funcs=[formula_match_v3, format_reward],
    args=rft_training_args,
    train_dataset=rft_train_dataset,
    callbacks=[metric_file_logger],
    peft_config=peft_config,
)

In [None]:
# NOTE(review): bare attribute access with no call -- this only looks up
# `load` on the trainer and performs no loading. Looks like a leftover from
# interactive exploration; confirm whether a method call was intended.
rft_trainer.load

In [None]:
# Automatically disconnect and shut down the Colab runtime
# Both options below run unconditionally, one after the other; executing this
# cell (including via "Run all") ends the session.
import os
from google.colab import runtime

# Option 1: Using the Colab `runtime` module (cleanest method)
runtime.unassign()

# Option 2: Force shutdown (in case unassign doesn't work)
# os._exit terminates the process immediately, without running cleanup handlers.
os._exit(0)