In [1]:
import json
from pathlib import Path
from typing import Optional

from datasets import load_dataset
from transformers import PreTrainedTokenizerBase

In [2]:
def process(
    tokenizer: PreTrainedTokenizerBase,
    tokenizer_name: str,
    output_dir: str,
    test_size: float = 0.1,
    seed: Optional[int] = 42,
    num_proc: Optional[int] = None,
    dataset_name: str = "wikipedia",
    dataset_config: str = "20220301.en",
):
    """Tokenize a Wikipedia dump and save train/test splits plus metadata to disk.

    Loads the ``train`` split of ``dataset_name``/``dataset_config``, holds out
    ``test_size`` of it, tokenizes the ``text`` column with ``tokenizer`` and
    writes::

        <output_dir>/train/     tokenized training split (torch format)
        <output_dir>/test/      tokenized held-out split (torch format)
        <output_dir>/meta.json  {"tokenizer_name": ..., plus split settings}

    Every source column is dropped, so each saved row holds only ``input_ids``.

    Args:
        tokenizer: Tokenizer applied to each article's ``text``.
        tokenizer_name: Name recorded in ``meta.json`` so consumers can reload
            the matching tokenizer.
        output_dir: Directory that receives the splits and ``meta.json``;
            created if it does not exist.
        test_size: Fraction of articles placed in the held-out ``test`` split.
        seed: Seed for the train/test shuffle so the split is reproducible.
            Pass ``None`` for a non-deterministic split.
        num_proc: Worker processes used for tokenization; ``None`` tokenizes
            in the current process.
        dataset_name: Dataset passed to ``load_dataset``.
        dataset_config: Configuration name passed to ``load_dataset``
            (e.g. ``"20220301.en"``).
    """

    def tokenize_batch(batch):
        # One tokenizer call per chunk of articles instead of one per article,
        # and no per-row "pt" tensor allocation. Like tokenizer.encode(text),
        # __call__ adds special tokens by default, so the ids are unchanged.
        return {"input_ids": tokenizer(batch["text"])["input_ids"]}

    wiki_en = load_dataset(dataset_name, dataset_config, split="train")
    wiki_split = wiki_en.train_test_split(test_size=test_size, seed=seed)

    save_location = Path(output_dir)
    save_location.mkdir(parents=True, exist_ok=True)

    # The held-out split is saved under the name "test".
    for split_name in ("train", "test"):
        split = wiki_split[split_name]
        tokenized = split.map(
            tokenize_batch,
            batched=True,
            num_proc=num_proc,
            # Drop every source column rather than a hard-coded list, so schema
            # differences between dumps cannot leave stray columns behind.
            remove_columns=split.column_names,
        )
        tokenized.set_format("torch", columns=["input_ids"])
        tokenized.save_to_disk(save_location / split_name)

    with open(save_location / "meta.json", "w") as f:
        json.dump(
            {
                "tokenizer_name": tokenizer_name,
                "dataset": dataset_name,
                "dataset_config": dataset_config,
                "test_size": test_size,
                "seed": seed,
            },
            f,
            indent=2,
        )

In [5]:
def get_data(data, index: int = 0, column: str = "input_ids"):
    """Return ``column`` of the row at ``index``, wrapped as a one-row batch.

    ``data.select([index])`` keeps the dataset's formatting. For the
    torch-formatted split loaded in this notebook, the result therefore has a
    leading batch dimension of 1 (``torch.Size([1, n_tokens])``) rather than
    being a 1-D sequence.

    Args:
        data: A ``datasets.Dataset``, or anything supporting ``select`` and
            column indexing.
        index: Row to fetch; defaults to the first row.
        column: Column to return; defaults to ``"input_ids"``.
    """
    return data.select([index])[column]

In [4]:
from transformers import AutoTokenizer

In [5]:
# Tokenizer handed to `process`; "gpt2" is the model identifier to load.
gpt2 = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="gpt2")

In [None]:
# Tokenize the dump and write train/, test/ and meta.json under data/wikipedia.
process(
    output_dir="data/wikipedia",
    tokenizer_name="gpt2",
    tokenizer=gpt2,
)

In [6]:
from datasets import load_from_disk

In [7]:
# Reload the held-out split that `process` saved to disk.
test_split_dir = "data/wikipedia/test"
data = load_from_disk(test_split_dir)

In [9]:
# input_ids of the first held-out example, as a one-row batch.
result = get_data(data=data)

In [10]:
# Batch shape: (1, number of tokens in the first example).
result.size()

torch.Size([1, 282])

In [12]:
# Number of articles in the held-out split.
data.num_rows

645867

In [14]:
# Total token count of the first example (all elements of the 1-row batch).
result.numel()

282