In [None]:
import os

try:
  from GOOGLE_INTERNAL_PACKAGE_PATH.pyglib import gfile
  file_open = gfile.Open

  NOTEBOOK_ENV = "g3"
except Exception:
  # Any failure to load the internal file API selects the open-source setup,
  # where files are read from the "tunix" Cloud Storage bucket.
  NOTEBOOK_ENV = "git"

  from google.cloud import storage

  client = storage.Client()
  bucket = client.bucket("tunix")

  def file_open(path, mode="r"):
    """Opens an object of the "tunix" bucket as a file-like object.

    Args:
      path: Either a full "gs://tunix/<object>" URI or an object name relative
        to the bucket root.
      mode: File mode forwarded to the blob's `open()`.

    Returns:
      Whatever `bucket.blob(<object name>).open(mode)` returns.
    """
    # `bucket.blob()` takes the object name *inside* the bucket, so the
    # "gs://tunix/" scheme-and-bucket prefix has to be removed first; handing
    # it the whole URI would address an object literally named "gs://tunix/...".
    uri_prefix = "gs://tunix/"
    blob_name = path[len(uri_prefix):] if path.startswith(uri_prefix) else path
    return bucket.blob(blob_name).open(mode)

import pandas as pd
from datasets import Dataset

# Root locations of datasets and model checkpoints for the detected environment.
if NOTEBOOK_ENV != "g3":
  DATA_PATH_PREFIX = "gs://tunix/rl/data"
  MODEL_PATH_PREFIX = "gs://tunix/rl/models"
else:
  # Data and models share one root in the internal environment.
  DATA_PATH_PREFIX = "/GOOGLE_INTERNAL_STOAGE_PATH/gg-d/home/qwix-dev/"
  MODEL_PATH_PREFIX = "/GOOGLE_INTERNAL_STOAGE_PATH/gg-d/home/qwix-dev/"

# Training set (JSON) and evaluation set (parquet), both below DATA_PATH_PREFIX.
DEEPSCALER_DATA_PATH = os.path.join(
    DATA_PATH_PREFIX,
    "DeepScaleR-Preview-Dataset/deepscaler.json",
)
AIME_2024_DATA_PATH = os.path.join(
    DATA_PATH_PREFIX,
    "HuggingFaceH4/aime_2024/train-00000-of-00001.parquet",
)

def create_datasets(
    train_ds_path: str = DEEPSCALER_DATA_PATH,
    test_ds_path: str = AIME_2024_DATA_PATH
):
  """Loads the train and test splits and adds the evaluation columns.

  Args:
    train_ds_path: Location of the training data; read with `pd.read_json`.
    test_ds_path: Location of the test data; read with `pd.read_parquet`.

  Returns:
    A `(train_ds, test_ds)` pair of `Dataset` objects. Every example is mapped
    to carry "question" (from "problem"), "ground_truth" (from "answer") and a
    constant "data_source" of "math".
  """
  def _to_math_record(row):
    record = {"question": row["problem"], "ground_truth": row["answer"]}
    record["data_source"] = "math"
    return record

  # Both files are opened together; the parquet file is read in binary mode.
  with file_open(train_ds_path) as train_file, file_open(test_ds_path, 'rb') as test_file:
    frames = (pd.read_json(train_file), pd.read_parquet(test_file))

  train_split, test_split = (
      Dataset.from_pandas(frame).map(_to_math_record) for frame in frames
  )
  return train_split, test_split

train_ds, test_ds = create_datasets()

# Sanity check: show the first example of each split (nothing is printed for an
# empty split).
for split in (train_ds, test_ds):
  for s in split:
    print(s)
    break


In [None]:
import jax

from flax import nnx
import os

# Prefer an ecolab "adhoc" import context (with reload of `tunix`) when `etils`
# provides one; otherwise fall back to a no-op context so the imports below are
# ordinary imports.
try:
  from etils import ecolab
  cm = ecolab.adhoc(
      source=ecolab.FROM_NOTEBOOK_OR_HEAD,
      reload='tunix',
      behavior='preferred',
      cell_autoreload=True,
  )
except Exception:
  # `Exception` rather than a bare `except:` so that KeyboardInterrupt and
  # SystemExit still propagate instead of silently selecting the fallback.
  import contextlib
  cm = contextlib.nullcontext()

with cm:
  from tunix.models.qwen2 import params as params_lib
  from tunix.models.qwen2 import model as model_lib
  from tunix.generate import sampler as sampler_lib
  
# The checkpoint directory must be a *relative* component: `os.path.join` drops
# everything before an absolute component, so a leading "/" here would discard
# MODEL_PATH_PREFIX and leave just "/DeepSeek-R1-Distill-Qwen-1.5B".
MODEL_PATH = os.path.join(MODEL_PATH_PREFIX, "DeepSeek-R1-Distill-Qwen-1.5B")

# 1 x 4 device mesh with axes named 'fsdp' and 'tp'.
mesh = jax.make_mesh((1, 4), ('fsdp', 'tp'))
config = model_lib.ModelConfig.deepseek_r1_distill_qwen_1_5b()
model = params_lib.create_model_from_safe_tensors(MODEL_PATH, config, mesh)
# nnx.display(model)

In [None]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# Sampler with an 8192-entry cache; the remaining cache dimensions are read
# from the loaded model's config.
sampler = sampler_lib.Sampler(
    model,
    tokenizer,
    sampler_lib.CacheConfig(
        cache_size=8192,
        num_layers=model.config.num_layers,
        num_kv_heads=model.config.num_kv_heads,
        head_dim=model.config.head_dim,
    ),
)



In [None]:
from pprint import pprint

# Question text of the first test example.
q = next(iter(test_ds.select(range(1))))['question']

INSTRUCTION = "Let's think step by step, and put your final answer within \\boxed{}."
PROMPT = f"{q} {INSTRUCTION}"

# Single-turn conversation: question and instruction share one user message.
messages = [dict(role="user", content=PROMPT)]

# Render the conversation to a prompt string (no tokenization here).
inputs = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    return_dict=False,
    add_generation_prompt=True,
)
print(inputs)

out = sampler(
    [inputs],
    max_generation_steps=1024,
    echo=True,
)
pprint(out.text)

In [None]:
# Instruction placed in front of every question by `templatize`.
SYSTEM_PROMPT = "Let's think step by step, and put your final answer within \\boxed{}."

# Format string with `question` and `system_prompt` fields.
TEMPLATE = """{question}{system_prompt}"""

def templatize(prompts, tokenizer):
  """Renders each prompt through the tokenizer's chat template.

  Every prompt becomes a two-message conversation: SYSTEM_PROMPT followed by
  the prompt itself, both sent with the "user" role.
  NOTE(review): the instruction is a "user" turn rather than a "system" turn;
  confirm that this is intended.

  Args:
    prompts: Iterable of prompt strings.
    tokenizer: Object exposing `apply_chat_template`.

  Returns:
    A list with one `apply_chat_template` result per prompt, in input order.
  """
  def _render(prompt):
    conversation = [
        {"role": "user", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ]
    return tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=True,
    )

  return [_render(prompt) for prompt in prompts]

def generate(
    question, tokenizer, sampler, max_generation_steps, temperature=0.7, top_k=50, top_p=0.95, seed=None
):
  """Generates text for one question or a batch of questions.

  Args:
    question: A single question string or a sequence of question strings.
    tokenizer: Tokenizer handed to `templatize` to render the prompts.
    sampler: Callable sampler; invoked with the rendered prompts and the
      sampling options below.
    max_generation_steps: Forwarded to the sampler.
    temperature: Forwarded to the sampler.
    top_k: Forwarded to the sampler.
    top_p: Forwarded to the sampler.
    seed: Forwarded to the sampler unchanged (None included).

  Returns:
    The `text` attribute of the sampler output, unmodified. A lone string
    question is wrapped into a one-element batch before sampling and the result
    is NOT unwrapped again, so callers get the same kind of value for both
    input forms.
  """
  # The previous "return output[0] for a string input" branch could never run
  # because `question` had already been rebound to a list; it is dropped here
  # without changing what callers receive.
  prompts = [question] if isinstance(question, str) else question
  input_batch = templatize(prompts, tokenizer)

  out_data = sampler(
      input_strings=input_batch,
      max_generation_steps=max_generation_steps,
      temperature=temperature,
      top_k=top_k,
      top_p=top_p,
      echo=True,
      seed=seed,
  )

  return out_data.text

In [None]:
from tqdm.auto import tqdm
import re

# Captures the first run of digits/dots that follows a "\boxed" marker; DOTALL
# lets the lazy gap between the marker and the number span line breaks.
match_numbers = re.compile(
    r"\\boxed.*?([\d\.]{1,})",
    flags=re.MULTILINE | re.DOTALL,
)
def last_boxed_only_string(string):
    """Returns the last "\\boxed{...}" (or "\\fbox{...}") span of `string`.

    The search starts at the last "\\boxed"; "\\fbox" is only considered when
    no "\\boxed" exists. From there braces are balanced until the group that
    was opened first is closed again.

    Returns:
        The substring from the marker through its matching closing brace, or
        None when there is no marker or the braces never balance.
    """
    start = string.rfind("\\boxed")
    if start < 0:
        start = string.rfind("\\fbox")
    if start < 0:
        return None

    depth = 0
    for pos in range(start, len(string)):
        ch = string[pos]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            # Balance is only tested on a closing brace, as the span ends there.
            if depth == 0:
                return string[start : pos + 1]

    return None


def remove_boxed(s):
    """Strips the surrounding "\\boxed{" ... "}" from `s`.

    Args:
        s: Candidate string, e.g. the result of `last_boxed_only_string`; may
            be None.

    Returns:
        The text between "\\boxed{" and the final "}", or None when `s` is not
        a string of exactly that form.
    """
    left = "\\boxed{"
    # Explicit checks instead of `assert`: assertions are removed when Python
    # runs with -O, which would let malformed input through the slice below.
    if not isinstance(s, str):
        return None
    if not (s.startswith(left) and s.endswith("}")):
        return None
    return s[len(left) : -1]

def extract_boxed_answer(solution: str) -> str:
    """Returns the contents of the last LaTeX \\boxed{} in `solution`.

    The last boxed span is located with `last_boxed_only_string` and unwrapped
    with `remove_boxed`; the result is whatever `remove_boxed` yields for it.
    """
    boxed_span = last_boxed_only_string(solution)
    return remove_boxed(boxed_span)

def extract_answer(passage: str) -> str:
    """Returns the boxed answer found in `passage`.

    Passages that do not contain "\\boxed" yield None; all others are delegated
    to `extract_boxed_answer`.
    """
    if "\\boxed" not in passage:
        return None
    return extract_boxed_answer(passage)

# Format check: the text contains a "\boxed{...}" with non-empty contents.
# The earlier pattern anchored "\boxed" to the end of a line (only whitespace
# was allowed after it), so it could never match "\boxed{42}".
# Non-empty contents are required so that a bare "\boxed{}" -- as it appears in
# the instruction text of the prompt -- does not count as a formatted answer.
match_answer = re.compile(
    r"\\boxed\s*\{[^}]+\}",
    flags=re.MULTILINE | re.DOTALL,
)

def _parse_float(value):
  """Parses `value` as a float; returns None when it is not numeric."""
  try:
    return float(str(value).strip())
  except ValueError:
    return None


def _percent(count, total):
  """Returns `count` as a percentage of `total` (0.0 when `total` is 0)."""
  return count / total * 100 if total else 0.0


def evaluate(
    dataset,
    sampler,
    tokenizer,
    max_generation_steps=4096,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    num_passes=1,
    corr_lst=False,
    make_lst=False,
):
  """Computes accuracy and percentage of outputs matching the format.

  Args:
    dataset: Iterable of mappings with "question" and "answer" entries. Each
      entry may hold a batch (sequence) of values or, when the dataset yields
      one example at a time, a single string.
    sampler: Sampler forwarded to `generate`.
    tokenizer: Tokenizer forwarded to `generate`.
    max_generation_steps: Forwarded to `generate`.
    temperature: Forwarded to `generate`.
    top_k: Forwarded to `generate`.
    top_p: Forwarded to `generate`.
    num_passes: Number of generations per question; pass `p` uses `seed=p`.
    corr_lst: With `make_lst`, collect correctly answered questions when True
      and incorrectly answered ones when False.
    make_lst: Whether to also return the collected
      `(question, answer, responses)` triples.

  Returns:
    `(corr, total, accuracy %, partial accuracy %, format %)`; all percentages
    are 0.0 when no question was evaluated. With `make_lst`, a pair of that
    tuple and the list of collected triples.
  """

  response_lst = []
  corr = 0
  partially_corr = 0
  corr_format = 0
  total = 0

  for batch in tqdm(dataset):
    answers = batch["answer"]
    questions = batch["question"]
    # A single example carries plain strings. Without wrapping them, `len()`
    # and `zip()` below would walk over the characters of the question and of
    # the answer instead of over examples.
    if isinstance(questions, str):
      questions = [questions]
      answers = [answers]

    multiple_call_responses = [[] for _ in range(len(questions))]
    for p in range(num_passes):
      responses = generate(
          questions, tokenizer, sampler, max_generation_steps, temperature, top_k, top_p, seed=p
      )
      print(responses)
      for idx, response in enumerate(responses):
        multiple_call_responses[idx].append(response)

    for question, multiple_call_response, answer in zip(
        questions, multiple_call_responses, answers
    ):
      # check answer
      corr_ctr_per_question = 0
      partially_corr_per_question = 0
      corr_format_per_question = 0
      expected = _parse_float(answer)
      for response in multiple_call_response:
        extracted_response = extract_answer(response)
        if extracted_response is None:
          # Sentinel for "no boxed answer found".
          extracted_response = "-1000000"
        predicted = _parse_float(extracted_response)

        if predicted is None or expected is None:
          # Non-numeric prediction or reference: cannot be compared.
          print("SKIPPED")
        else:
          if predicted == expected:
            corr_ctr_per_question += 1

          # "Partially correct" means within 10% of the reference. A zero
          # reference has no meaningful ratio, so only an exact match counts.
          if expected != 0:
            within_tolerance = 0.9 <= predicted / expected <= 1.1
          else:
            within_tolerance = predicted == 0
          if within_tolerance:
            partially_corr_per_question += 1

        # check answer generated
        if match_answer.search(response) is not None:
          corr_format_per_question += 1

        # Stop early once one pass has satisfied every criterion.
        if (
            corr_ctr_per_question > 0
            and partially_corr_per_question > 0
            and corr_format_per_question > 0
        ):
          break

      if corr_ctr_per_question > 0:
        corr += 1
        if corr_lst and make_lst:
          response_lst.append((question, answer, multiple_call_response))
      else:
        if not corr_lst and make_lst:
          response_lst.append((question, answer, multiple_call_response))
      if partially_corr_per_question > 0:
        partially_corr += 1
      if corr_format_per_question > 0:
        corr_format += 1

      total += 1
      if total % 10 == 0:
        print(
            f"===> {corr=}, {total=}, {corr / total * 100=}, "
            f"{partially_corr / total * 100=}, {corr_format / total * 100=}"
        )

  # `_percent` guards against an empty dataset (total == 0).
  to_return = (
      corr,
      total,
      _percent(corr, total),
      _percent(partially_corr, total),
      _percent(corr_format, total),
  )
  if make_lst:
    return to_return, response_lst
  return to_return

In [None]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# Sampler for the full evaluation: the cache holds the 32768 generation steps
# requested below plus 2048 further entries.
sampler = sampler_lib.Sampler(
    model,
    tokenizer,
    sampler_lib.CacheConfig(
        cache_size=32768 + 2048,
        num_layers=model.config.num_layers,
        num_kv_heads=model.config.num_kv_heads,
        head_dim=model.config.head_dim,
    ),
)

evaluate(
    test_ds,
    sampler,
    tokenizer,
    max_generation_steps=32768,
)