In [None]:
from transformers import AutoTokenizer
from datasets import load_dataset, DatasetDict
from tqdm.notebook import tqdm

In [None]:
# Notebook configuration, gathered in one dictionary literal so every tunable
# value is visible at a glance. Key order matches the order the values are
# declared in.
args = {
    "template_model_name": "mistralai/Mistral-7B-v0.1",  # tokenizer to load
    "sequence_length": 2048,  # assigned to tokenizer.model_max_length
    "stride": 64,
    "additional_special_tokens": [],  # extra special tokens appended to the built-in list
    "chat_template": None,  # optional chat template override
    "dataset_path": "wikitext",  # passed to load_dataset
    "dataset_size": 0,  # 0 means "use the whole training split"
    "dataset_split": 0.9,  # fraction of examples kept for training
    "seed": 42,
    "shuffle": False,
}

# Echo the configuration as the cell output
args

In [None]:
# Load tokenizer
print(f"Loading the tokenizer from {args['template_model_name']}...")
tokenizer = AutoTokenizer.from_pretrained(args["template_model_name"])

# Set up the tokenizer
tokenizer.model_max_length = args["sequence_length"]

# Add assistant token to the tokenizer for the chat template because it is not in the vocabulary without a space in front of it!
tokenizer.add_tokens(["assistant"])
print(
    "'assistant' token added to the tokenizer because it is not in the vocabulary without a space in front of it!"
)

# Special tokens that are always registered
additional_special_tokens = [
    "<|im_start|>",
    "<|im_end|>",
    "<|named_user|>",  # Named user. For future use. Example: "<|im_start|><|named_user|>Alice\n<Alice's message><|im_end|>"
    "<|named_assistant|>",  # Named assistant. For future use. Example: "<|im_start|><|named_assistant|>Assistant George\n<Assistant George's message><|im_end|>"
    "<|mem_start|>",
    "<|mem_end|>",  # Memory start and end tokens. For future use. Store hidden information in the context, e.g. "<|mem_start|>Alice's birthday is 12th May.<|mem_end|>"
    "<|pause|>",  # Pause token. For future use. See https://arxiv.org/abs/2310.02226.pdf Think before you speak: Training Language Models With Pause Tokens
]

# Append the user-configured special tokens, if any
if args["additional_special_tokens"]:
    additional_special_tokens += args["additional_special_tokens"]

# Register the special tokens first, so the padding below is computed from the
# real vocabulary size rather than from an assumed count (a token that already
# exists in the vocabulary does not grow it).
tokenizer.add_special_tokens(
    {"additional_special_tokens": additional_special_tokens},
    replace_additional_special_tokens=False,
)

# Add <|spare_1|>, <|spare_2|>, etc. until the vocab size is a multiple of 8.
# Nothing is added when the size is already aligned (the previous arithmetic
# added a full group of 8 unnecessary spare tokens in that case).
spare_index = 1
while len(tokenizer) % 8 != 0:
    spare_token = f"<|spare_{spare_index}|>"
    # add_special_tokens returns how many tokens were actually added
    num_added = tokenizer.add_special_tokens(
        {"additional_special_tokens": [spare_token]},
        replace_additional_special_tokens=False,
    )
    if num_added:
        additional_special_tokens.append(spare_token)
    spare_index += 1

if args["additional_special_tokens"]:
    print("Additional special tokens added to the tokenizer.")

    # Print the token IDs of the special tokens
    for token in args["additional_special_tokens"]:
        print(f"{token}: {tokenizer(token)}")

# Assert that the vocab size is a multiple of 8
assert (
    len(tokenizer)
) % 8 == 0, "The vocabulary size is not a multiple of 8. Fix the padding code, dumbass!"

# Set up the chat template
if args["chat_template"]:
    tokenizer.chat_template = args["chat_template"]

print(f"Tokeniser loaded with {len(tokenizer)} tokens in the vocabulary")

In [None]:
# Prepare the dataset
def prepare_dataset(
    dataset: "DatasetDict",
    dataset_size: int,
    dataset_split: float,
    shuffle: bool = False,
    seed: int = 42,
) -> "DatasetDict":
    """Select a subset of the training split and divide it into train/eval sets.

    Args:
        dataset: Mapping of splits; only ``dataset["train"]`` is used.
        dataset_size: Number of leading examples to keep. A value ``<= 0``
            keeps the whole training split; a value larger than the split is
            clamped to the split size.
        dataset_split: Fraction of the selected examples used for training,
            strictly between 0 and 1. The remainder becomes the evaluation set.
        shuffle: Shuffle the dataset before selecting, and shuffle again when
            splitting.
        seed: Seed used for both shuffling and splitting.

    Returns:
        The result of ``train_test_split``, with ``"train"`` and ``"test"``
        entries.

    Raises:
        ValueError: If ``dataset_split`` is not strictly between 0 and 1.
    """
    if not 0.0 < dataset_split < 1.0:
        raise ValueError(
            f"dataset_split must be strictly between 0 and 1, got {dataset_split!r}"
        )

    print("Preparing the dataset...")

    # Shuffle if required
    if shuffle:
        dataset = dataset.shuffle(seed=seed)

    train_examples = dataset["train"]
    available = len(train_examples)

    # Select the first dataset_size examples from the training set
    if dataset_size > 0:
        if dataset_size > available:
            # Asking for more rows than exist would make select() fail.
            print(
                f"Requested {dataset_size} examples but only {available} are available; using all of them."
            )
            dataset_size = available
        print("Selecting", dataset_size, "examples from the dataset...")
        selected = train_examples.select(range(dataset_size))
    else:
        dataset_size = available
        print("Using the entire dataset of size", dataset_size)
        selected = train_examples

    # Split the dataset into training and evaluation sets (dataset_split for training, 1 - dataset_split for evaluation)
    print("Splitting the dataset into training and evaluation sets...")
    prepared_dataset = selected.train_test_split(
        test_size=1 - dataset_split, seed=seed, shuffle=shuffle
    )

    # Report the sizes actually produced by the split; a locally rounded
    # estimate can be off by one from the library's own rounding.
    print("Training set size:", len(prepared_dataset["train"]))
    print("Evaluation set size:", len(prepared_dataset["test"]))

    # Return the training and evaluation datasets
    return prepared_dataset

In [None]:
# Load the dataset
# Datasets that ship several configurations (wikitext is one, e.g.
# "wikitext-103-raw-v1") cannot be loaded without naming a configuration.
# Set args["dataset_config_name"] to pick one; when the key is absent, None is
# passed and load_dataset behaves exactly as if no name had been given.
dataset_config_name = args.get("dataset_config_name")

print(f"Loading the dataset from {args['dataset_path']}")
dataset = load_dataset(args["dataset_path"], name=dataset_config_name)

# Show the loaded dataset as the cell output
dataset

In [None]:
# Build the train/eval splits from the loaded dataset using the notebook configuration
prepared_dataset = prepare_dataset(
    dataset,
    args["dataset_size"],
    args["dataset_split"],
    seed=args["seed"],
    shuffle=args["shuffle"],
)

In [None]:
# Function to tokenize a batch of texts
def tokenize_function(examples):
    """Tokenize one batch of examples with the notebook's tokenizer.

    Args:
        examples: A batch as passed by ``map(batched=True)``; its ``"text"``
            entry is handed to the tokenizer.

    Returns:
        The tokenizer output for the batch. Padding and truncation are both
        disabled.

    The function deliberately keeps no counters: ``map`` runs it in worker
    processes (``num_proc=4``), so updates to a module-level variable would
    never reach the notebook's own process.
    """
    return tokenizer(
        examples["text"],
        padding=False,
        truncation=False,
        return_overflowing_tokens=True,
    )


# Tokenize the dataset
print("Tokenizing the dataset...")
tokenized_dataset = prepared_dataset.map(
    tokenize_function,
    batched=True,
    num_proc=4,
    remove_columns=dataset["train"].column_names,
)

# Count tokens from the mapped result, summing the length of every example's
# input_ids across all splits.
total_tokens = sum(
    len(example["input_ids"])
    for split in tokenized_dataset.values()
    for example in tqdm(split)
)

print("Total tokens:", total_tokens)

In [None]:
# Tokenize the dataset
print("Tokenizing the dataset...")
tokenized_dataset = prepared_dataset.map(
    tokenize_function,
    batched=True,
    num_proc=4,
    remove_columns=dataset["train"].column_names,
)

# Count tokens from the mapped result rather than from a global counter:
# map() runs tokenize_function in worker processes (num_proc=4), so a counter
# incremented there is never visible in this process.
total_tokens = sum(
    len(example["input_ids"])
    for split in tokenized_dataset.values()
    for example in tqdm(split)
)

print("Total tokens:", total_tokens)

In [None]:
# Show the tokenized dataset as the cell output to inspect the result of the map step
tokenized_dataset

In [None]:
# Walk the training split once and accumulate the length of each example's
# input_ids; tqdm renders the progress bar.
train_examples = tokenized_dataset["train"]
total_tokens = 0
for row in tqdm(train_examples):
    total_tokens = total_tokens + len(row["input_ids"])

In [None]:
# Total token count of the training split: map every example to the length of
# its input_ids and add the lengths up, with tqdm showing progress.
total_tokens = sum(
    map(
        lambda row: len(row["input_ids"]),
        tqdm(tokenized_dataset["train"]),
    )
)

In [None]:
# Report the token count with thousands separators, plus the same figure in billions
tokens_in_billions = total_tokens / 1e9
summary = f"Total number of tokens in the dataset: {total_tokens:,} ({tokens_in_billions:.2f} billion tokens)"
print(summary)