In [None]:
# Enable IPython's autoreload extension in mode 2, so edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Analyze query and response token counts of dataset(s)

In [None]:
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
from typing import Union

In [None]:
# Hugging Face Hub id of the tokenizer used for every measurement in this notebook.
tokenizer_id: str = "deepseek-ai/deepseek-math-7b-rl"
# Load the tokenizer once so the analysis calls can share the same instance.
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = (
    AutoTokenizer.from_pretrained(tokenizer_id)
)

In [None]:
def calculate_token_stats(
    dataset: Union[Dataset, str] = "hkust-nlp/dart-math-hard",
    tokenizer: Union[
        PreTrainedTokenizer, PreTrainedTokenizerFast, str
    ] = "deepseek-ai/deepseek-math-7b-rl",
    split: str = "train",
    query_column: str = "query",
    response_column: str = "response",
    batch_size: int = 1024,
    num_proc: Union[int, None] = 16,
) -> dict:
    """Compute average and total token counts of a dataset's queries and responses.

    Token counts are the lengths of the ``input_ids`` the tokenizer returns
    when called with ``truncation=False, padding=False`` and otherwise default
    settings, so any special tokens the tokenizer adds by default are counted.

    Args:
        dataset: An already loaded dataset, or a dataset id that is loaded
            with ``load_dataset(dataset, split=split)``.
        tokenizer: An already loaded tokenizer, or a tokenizer id that is
            loaded with ``AutoTokenizer.from_pretrained``.
        split: Split to load when ``dataset`` is given as a string; ignored
            otherwise.
        query_column: Name of the column holding the query texts.
        response_column: Name of the column holding the response texts.
        batch_size: Number of examples handed to the tokenizer per batch.
        num_proc: Number of worker processes forwarded to ``dataset.map``.

    Returns:
        A dict with the keys ``average_query_token_size``,
        ``average_response_token_size``, ``total_query_tokens`` and
        ``total_response_tokens``. For an empty dataset the averages are
        ``0.0`` and the totals are ``0``.
    """
    if isinstance(dataset, str):
        dataset = load_dataset(dataset, split=split)
    if isinstance(tokenizer, str):
        tokenizer = AutoTokenizer.from_pretrained(tokenizer)

    # Nothing to tokenize; returning early also avoids dividing by zero below.
    if len(dataset) == 0:
        return {
            "average_query_token_size": 0.0,
            "average_response_token_size": 0.0,
            "total_query_tokens": 0,
            "total_response_tokens": 0,
        }

    def tokenize_batch(examples) -> dict[str, list[int]]:
        """Return the per-example token counts for one batch of examples."""
        query_tokens = tokenizer(
            examples[query_column], truncation=False, padding=False
        )
        response_tokens = tokenizer(
            examples[response_column], truncation=False, padding=False
        )

        return {
            "query_token_size": [len(tokens) for tokens in query_tokens["input_ids"]],
            "response_token_size": [
                len(tokens) for tokens in response_tokens["input_ids"]
            ],
        }

    # Apply the tokenization to the entire dataset, keeping only the size columns.
    token_sizes = dataset.map(
        tokenize_batch,
        batched=True,
        batch_size=batch_size,
        num_proc=num_proc,
        remove_columns=dataset.column_names,
    )

    # Fetch each column once and sum it once; the results feed both the
    # totals and the averages.
    query_sizes = token_sizes["query_token_size"]
    response_sizes = token_sizes["response_token_size"]
    total_query_tokens = sum(query_sizes)
    total_response_tokens = sum(response_sizes)

    return {
        "average_query_token_size": total_query_tokens / len(query_sizes),
        "average_response_token_size": total_response_tokens / len(response_sizes),
        "total_query_tokens": total_query_tokens,
        "total_response_tokens": total_response_tokens,
    }

In [None]:
# Token statistics for the DART-Math "hard" dataset, reusing the tokenizer loaded above.
calculate_token_stats(dataset="hkust-nlp/dart-math-hard", tokenizer=tokenizer)

Map (num_proc=16):   0%|          | 0/585392 [00:00<?, ? examples/s]

{'average_query_token_size': 96.58649076174598,
 'average_response_token_size': 480.3031387514691,
 'total_query_tokens': 56540959,
 'total_response_tokens': 281165615}

In [None]:
# Token statistics for the DART-Math "uniform" dataset, reusing the tokenizer loaded above.
calculate_token_stats(dataset="hkust-nlp/dart-math-uniform", tokenizer=tokenizer)

Map (num_proc=16):   0%|          | 0/590705 [00:00<?, ? examples/s]

{'average_query_token_size': 68.52043236471674,
 'average_response_token_size': 273.8289179878281,
 'total_query_tokens': 40475362,
 'total_response_tokens': 161752111}