In [None]:
import re
import os
import numpy as np
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

In [None]:
# Choose which split of the dataset to prepare: True -> "train", False -> "test".
# The same flag is reused further down to decide whether the loader shuffles.
train = True
if train:
    split = "train"
else:
    split = "test"

# Load the requested split of the wikitext-103-v1 configuration.
ds = load_dataset("Salesforce/wikitext", "wikitext-103-v1", split=split)

# GPT-2 tokenizer; its end-of-sequence token is reused as the padding token.
tokenizer = AutoTokenizer.from_pretrained('openai-community/gpt2')
tokenizer.pad_token = tokenizer.eos_token

In [None]:
def tokenize(txt):
    """Encode the ``'text'`` entry of ``txt`` with the module-level ``tokenizer``.

    Args:
        txt: Mapping whose ``'text'`` value holds the text(s) to encode.

    Returns:
        The tokenizer's output for those texts, requested with truncation and
        ``max_length`` padding to 512 tokens and PyTorch tensors as the
        return format.
    """
    encode_options = {
        'truncation': True,
        'padding': 'max_length',
        'max_length': 512,
        'return_tensors': 'pt',
    }
    return tokenizer(txt['text'], **encode_options)

def len_filter(ex):
    """Return True when the example contains at least one non-padding token.

    The previous implementation located the first padding token with
    ``np.where`` and checked that its index was >= 1 (using a sentinel of
    1000 when no padding was present). That is equivalent to asking whether
    the first token differs from the padding token, so this version checks
    position 0 directly instead of scanning the whole sequence.

    Args:
        ex: Mapping with an ``'input_ids'`` sequence of token ids.

    Returns:
        bool: ``False`` if the sequence is empty or starts with
        ``tokenizer.pad_token_id``; ``True`` otherwise. Always a plain Python
        ``bool`` (the old code returned ``numpy.bool_`` on one path).
    """
    token_ids = ex['input_ids']
    # An empty sequence holds no real tokens. The old sentinel value made such
    # rows pass the filter by accident; they are rejected explicitly here.
    if len(token_ids) == 0:
        return False
    # The first padding token sits at index >= 1 (or is absent) exactly when
    # the token at index 0 is not padding.
    return bool(token_ids[0] != tokenizer.pad_token_id)

In [None]:
def _is_body_text(example):
    """Return True for rows whose text is non-empty and does not match the
    `` = ... = `` pattern at its start (presumably section-heading lines —
    verify against the dataset)."""
    text = example['text']
    if len(text) == 0:
        return False
    return re.match(r'( =)+.*?(= )+\n', text) is None


# Keep only the rows accepted by the predicate above.
ds = ds.filter(_is_body_text)

In [None]:
# Tokenize in batches, then keep only the rows that len_filter accepts.
ds = ds.map(tokenize, batched=True).filter(len_filter)

# Expose just the token ids, as torch tensors, when rows are read.
ds.set_format(columns=['input_ids'], type='torch')

# To avoid a Hugging Face error (presumably the tokenizers fork/parallelism
# complaint once DataLoader workers start — TODO confirm). Set only after the
# tokenization above has already run.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

row_count = len(ds)
print("Number of rows in %s dataset: %d" % (split, row_count))

In [None]:

# Batch the prepared dataset for iteration.
loader = DataLoader(
    ds,
    batch_size=32,
    shuffle=train,  # follows the train flag: shuffle only for the training split
    num_workers=2,
    pin_memory=True,
)