In [1]:
from typing import Dict
from datasets import Dataset, load_dataset
import os
import re

dataset_path = "openai/gsm8k"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def preprocess(text):
    if text is None:
        return " "
    text = text.strip()
    text = re.sub(r"\s+", " ", text)
    return text

In [3]:
def process_cot_example(
    example: Dict,
    tokenizer,
):
    question = preprocess(example["question"])
    attempt = preprocess(example["answer"])

    answer_parts = attempt.split("####")
    thinking = answer_parts[0].strip()
    answer = answer_parts[1].strip() if len(answer_parts) > 1 else ""

    assistant_text = (
        "[THINK]\n"
        + thinking
        + "\n[/THINK]\n"
        + "\n[ANSWER]\n"
        + answer
        + "\n[/ANSWER]\n"
    )

    text = tokenizer.apply_chat_template(
        [
            {"role": "user", "content": question},
            {
                "role": "assistant",
                "content": assistant_text,
            },
        ],
        tokenize=False,
    )
    return dict(text=text)


def preprocess_dataset(
    dataset: Dataset,
    tokenizer,
    processing_function=process_cot_example,
):
    processed_dataset = dataset.map(
        lambda x: processing_function(x, tokenizer),
        batched=False,
        remove_columns=dataset.column_names,
        load_from_cache_file=False,
    )
    return processed_dataset

In [4]:
def load_training_dataset(dataset_path: str):
    if os.path.exists(dataset_path):
        return load_dataset(dataset_path, "socratic")
    else:
        # from hf hub
        return load_dataset(dataset_path, "socratic")["train"]

In [5]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch

model_id = "mistralai/Mistral-7B-Instruct-v0.3"
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)

In [6]:
train_ds = load_training_dataset(dataset_path)
processed_train_ds = preprocess_dataset(train_ds, tokenizer, process_cot_example)

Map: 100%|██████████| 7473/7473 [00:01<00:00, 5422.83 examples/s]


In [7]:
processed_train_ds[0]

{'text': '<s>[INST] Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?[/INST] [THINK]\nHow many clips did Natalia sell in May? ** Natalia sold 48/2 = <<48/2=24>>24 clips in May. How many clips did Natalia sell altogether in April and May? ** Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\n[/THINK]\n\n[ANSWER]\n72\n[/ANSWER]</s>'}

In [8]:
def check_token_stats(processed_dataset, tokenizer):
    max_tokens = 0
    max_idx = 0
    total_tokens = 0
    min_tokens = float("inf")
    count = 0

    for idx, example in enumerate(processed_dataset):
        # More memory efficient - just get the length without storing tokens
        length = len(tokenizer.encode(example["text"], truncation=False))

        total_tokens += length
        count += 1

        if length < min_tokens:
            min_tokens = length

        if length > max_tokens:
            max_tokens = length
            max_idx = idx

    return {
        "max_tokens": max_tokens,
        "min_tokens": min_tokens if min_tokens != float("inf") else 0,
        "avg_tokens": total_tokens / count if count > 0 else 0,
        "max_idx": max_idx,
    }

In [9]:
check_token_stats(processed_train_ds, tokenizer)

{'max_tokens': 693,
 'min_tokens': 107,
 'avg_tokens': 256.35835675097013,
 'max_idx': 2345}