In [1]:
import os
import json

from datasets import load_dataset, load_from_disk
from tokenizers import BertWordPieceTokenizer
from transformers import BertTokenizerFast

In [2]:
# One-time download of the "20231101.simple" config of wikimedia/wikipedia,
# kept (commented out) as a record of where the on-disk "wiki" copy came from.
# wiki = load_dataset("wikimedia/wikipedia", "20231101.simple", split="train", cache_dir="data")
# wiki.save_to_disk("wiki")

# Fixed seed so the train/test split -- and everything derived from it
# (train.txt, the trained vocabulary) -- is identical on every fresh run.
SEED = 42

wiki = load_from_disk("wiki")

# Keep only the "text" column; every other column is dropped.
wiki = wiki.remove_columns([col for col in wiki.column_names if col != "text"])

# Hold out 10% of the rows as the test split.
d = wiki.train_test_split(test_size=0.1, seed=SEED)

In [3]:
def dataset_to_text(dataset, output_filename="data.txt"):
    """Write every entry of ``dataset["text"]`` to a text file.

    Each entry is written with ``print``, so it is followed by a newline.
    An existing file at ``output_filename`` is overwritten.

    Args:
        dataset: Object whose ``"text"`` item is an iterable of strings.
        output_filename: Path of the file to create.
    """
    # Explicit UTF-8: without it the platform default encoding is used, which
    # can raise UnicodeEncodeError on non-ASCII text (e.g. cp1252 on Windows).
    with open(output_filename, "w", encoding="utf-8") as f:
        for t in dataset["text"]:
            print(t, file=f)

# Dump each split of `d` to its own plain-text file.
for split_name, text_path in (("train", "train.txt"), ("test", "test.txt")):
    dataset_to_text(d[split_name], text_path)

In [5]:
# Tokens passed to the WordPiece trainer as special tokens.
special_tokens = [
    "[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "<S>", "<T>"
    ]

# --- Tokenizer training configuration ---
file = ["train.txt"]            # text file(s) the vocabulary is trained on
vocab_size = 30_522
max_length = 512                # truncation length; also written to the config below
truncate_longer_samples = True  # read by the encoding cell to pick the encoder
model_path = "pretrained-bert"  # directory the tokenizer files are written to

# Train under a dedicated name so that `tokenizer` only ever refers to the
# BertTokenizerFast instance loaded at the end of this cell.
wordpiece_tokenizer = BertWordPieceTokenizer()
wordpiece_tokenizer.train(files=file, vocab_size=vocab_size, special_tokens=special_tokens)
wordpiece_tokenizer.enable_truncation(max_length=max_length)

# exist_ok=True makes a separate os.path.exists() check unnecessary and avoids
# the check-then-create race.
os.makedirs(model_path, exist_ok=True)

wordpiece_tokenizer.save_model(model_path)

# Settings stored next to the saved tokenizer files.
tokenizer_config = {
    "do_lower_case": True,
    "unk_token": "[UNK]",
    "sep_token": "[SEP]",
    "pad_token": "[PAD]",
    "cls_token": "[CLS]",
    "mask_token": "[MASK]",
    "model_max_length": max_length,
    "vocab_size": vocab_size,
    "max_len": max_length,
}

# NOTE(review): transformers conventionally keeps tokenizer settings in
# "tokenizer_config.json", while "config.json" is the model-config filename.
# Confirm these settings are actually picked up by from_pretrained() below
# before relying on them -- TODO verify.
with open(os.path.join(model_path, "config.json"), "w", encoding="utf-8") as f:
    json.dump(tokenizer_config, f)

# Reload the saved files as the tokenizer used by the rest of the notebook.
tokenizer = BertTokenizerFast.from_pretrained(model_path)

# Quick smoke test of the reloaded tokenizer.
print(tokenizer("Helloworld!"))




{'input_ids': [2, 13692, 9944, 7, 3], 'token_type_ids': [0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1]}


In [None]:
def encode_with_truncation(example):
    """Tokenize ``example["text"]``, truncating and padding to ``max_length``.

    The special-tokens mask is requested alongside the regular encoding.
    """
    encoding_options = dict(
        truncation=True,
        max_length=max_length,
        padding="max_length",
        return_special_tokens_mask=True,
    )
    return tokenizer(example["text"], **encoding_options)

def encode_without_truncation(example):
    """Tokenize ``example["text"]`` with no truncation or padding arguments.

    The special-tokens mask is requested alongside the regular encoding.
    """
    texts = example["text"]
    encoded = tokenizer(texts, return_special_tokens_mask=True)
    return encoded

# Pick the encoder once, based on the truncation flag.
if truncate_longer_samples:
    encode = encode_with_truncation
else:
    encode = encode_without_truncation

# Tokenize both splits in batches.
train_dataset = d["train"].map(encode, batched=True)
test_dataset = d["test"].map(encode, batched=True)

# Truncated samples are exposed as torch tensors of ids + attention mask;
# otherwise the special-tokens mask is kept and no output type is set.
if truncate_longer_samples:
    format_options = {"type": "torch", "columns": ["input_ids", "attention_mask"]}
else:
    format_options = {"columns": ["input_ids", "attention_mask", "special_tokens_mask"]}

for encoded_split in (train_dataset, test_dataset):
    encoded_split.set_format(**format_options)