In [None]:
import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler
from datasets import load_dataset
import tqdm
from transformers import GPT2TokenizerFast
import re

In [2]:
data = load_dataset('Salesforce/wikitext', 'wikitext-103-raw-v1')
# data = load_dataset('tiny_shakespeare')
train = data['train']
validation = data['validation']
test = data['test']

In [3]:
bos_token = "<|BOS|>"

In [None]:

def data_clean(input: list[str], seq_len=129) -> str:
    ret = ""
    for line in input:
        if len(line) == 0:  continue
        # remove @'s surrounding some characters
        line = re.sub(r' @([.,\-])@ ', r'\1', line)
        # find titles of articles and add bos_token
        matches = re.match(r'^ = ?(.+?) =?\n', line)    # this finds all title and subsection text
        if matches != None:
            c = line.count('=')
            if c == 2:
                # start new article
                ret += " " + bos_token
        ret += line

    ret = ret.split(" ")
    chunks = []
    curr_chunk = []
    cur_len = 0
    
    for word in ret:
        if cur_len > seq_len:
            chunks.append(" ". join(curr_chunk))
            curr_chunk = [word]
            cur_len = 1

        else:
            curr_chunk.append(word)
            cur_len += 1

    return chunks

In [5]:

tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
tokenizer.add_special_tokens({"bos_token":bos_token})

1

In [6]:
train_join = data_clean(train['text'])
val_join = data_clean(validation['text'])
test_join = data_clean(test['text'])

In [None]:
train_tok = [tokenizer(chunk, max_length=128, truncation=True)['input_ids'] for chunk in tqdm.tqdm(train_join)]
val_tok = [tokenizer(chunk, max_length=128, truncation=True)['input_ids'] for chunk in val_join]
test_tok = [tokenizer(chunk, max_length=128, truncation=True)['input_ids'] for chunk in test_join]

  9%|▊         | 66093/771103 [00:21<03:47, 3102.72it/s]

In [None]:
torch.save(train_tok, 'data/train_data_token.pt')
torch.save(val_tok, 'data/val_data_token.pt')
torch.save(test_tok, 'data/test_data_token.pt')