## How many tokens are in my dataset?

In [None]:
import os
import sys

# Make the directory two levels above the working directory importable.
# NOTE: "__file__" is intentionally a string literal, not the variable: its
# absolute path resolves against the current working directory, so dirname()
# yields the working directory itself.
# Presumably this is what lets the `src.shared.config` import below resolve
# when run from the notebook's folder — confirm against the repo layout.
sys.path.append(
    os.path.abspath(os.path.join(os.path.dirname(os.path.abspath("__file__")), "../.."))
)

from datasets import DatasetDict, load_dataset
from transformers import AutoTokenizer
from src.shared.config import settings
import json
from functools import partial
from typing import Literal, Dict
from textwrap import dedent

# Model identifier passed to AutoTokenizer.from_pretrained() to load the tokenizer.
HF_GEMMA_ID = "google/gemma-3-1b-pt"

In [25]:
def tokenizer_init():
    """Load and return the pretrained tokenizer identified by ``HF_GEMMA_ID``."""
    return AutoTokenizer.from_pretrained(HF_GEMMA_ID)


def format_dataset_row(row, eos_token: str) -> Dict[Literal["text"], str]:
    """Render one dataset row as a two-line ``input:`` / ``label:`` example.

    The ``from_account`` field is taken out of the row and used as the label;
    every remaining field is serialized as JSON on the ``input:`` line. The
    label line is terminated with ``eos_token``.

    Args:
        row: Mapping of column name to value; must contain ``"from_account"``.
            The mapping passed in is not modified.
        eos_token: String appended directly after the label.

    Returns:
        A dict with a single ``"text"`` key holding the formatted example,
        with leading/trailing whitespace stripped.

    Raises:
        KeyError: If ``row`` has no ``"from_account"`` key.
    """
    # Work on a copy so popping the label never mutates the caller's row.
    features = dict(row)
    label = features.pop("from_account")

    # Build the two lines directly instead of dedenting an indented template:
    # dedent() runs after interpolation, so a newline inside the label or the
    # EOS token would zero the common margin and leave the template's
    # indentation in the output.
    text = f"input: {json.dumps(features)}\nlabel: {label}{eos_token}"
    return {"text": text.strip()}


def training_dataset_init(tokenizer) -> DatasetDict:
    """Load, format and tokenize the dataset named in ``settings``.

    Steps, in order:
      1. Load ``<hf_user_name>/<hf_dataset_repo_name>`` and shuffle it,
         passing ``0`` to ``shuffle``.
      2. Map every row of the ``train`` and ``test`` splits through
         ``format_dataset_row`` (using the tokenizer's EOS token), storing
         the formatted ``test`` split under ``validation`` and dropping
         ``test``.
      3. Tokenize the ``text`` column in batches and remove the raw columns
         listed below, including ``text`` itself.
      4. Set the output format to ``"pt"``.

    Args:
        tokenizer: Callable tokenizer exposing an ``eos_token`` attribute.

    Returns:
        The resulting dataset with ``train`` and ``validation`` splits.
    """
    repo_id = f"{settings.hf_user_name}/{settings.hf_dataset_repo_name}"
    dataset = load_dataset(repo_id).shuffle(0)

    to_text = partial(format_dataset_row, eos_token=tokenizer.eos_token)

    dataset["train"] = dataset["train"].map(to_text)
    # The "test" split is renamed to "validation" once it has been formatted.
    dataset["validation"] = dataset.pop("test").map(to_text)

    # Raw columns (plus the intermediate "text") dropped after tokenization.
    raw_columns = [
        "transaction_date",
        "description",
        "amount",
        "category",
        "category_source",
        "card",
        "day_of_week",
        "from_account",
        "text",
    ]

    def tokenize_batch(batch):
        return tokenizer(batch["text"])

    dataset = dataset.map(tokenize_batch, batched=True, remove_columns=raw_columns)
    dataset.set_format("pt")
    return dataset

In [26]:
# Build the tokenized dataset, then count tokens per example across both splits.
dataset = training_dataset_init(tokenizer_init())

# One entry per example: the length of its "input_ids" sequence,
# train examples first, then validation examples.
token_counts = [
    len(input_ids)
    for split_name in ("train", "validation")
    for input_ids in dataset[split_name]["input_ids"]
]

n_tokens = sum(token_counts)
average_token_count = n_tokens / len(token_counts)

# Last expression of the cell: displayed as (total tokens, mean tokens per example).
n_tokens, average_token_count

(93603, 97.80877742946709)