In [60]:
import tiktoken
import torch
from datasets import load_dataset
import numpy as np
import math
import pickle
import math
from multiprocessing import Pool, cpu_count
import os
import time

In [61]:
# Parameters
block_size = 1024
save_path  = f'./dataset/wikipedia_ctx_{block_size}.dat'
num_proc   = cpu_count() # Number of processes to use when working with the dataset

In [62]:
tokenizer = tiktoken.get_encoding("p50k_base")
assert tokenizer.decode(tokenizer.encode("dogs are cool")) == "dogs are cool"
vocab_size = tokenizer.n_vocab

### Load dataset from HF

Using the `lucadiliello/english_wikipedia` dataset. This is about 6GB of text.

In [63]:
wikipedia_dataset = load_dataset("lucadiliello/english_wikipedia", split='train') # by default, the HF dataset puts everything into train.
wikipedia_dataset = wikipedia_dataset.remove_columns(['filename', 'source_domain', 'title', 'url'])
# wikipedia_dataset = wikipedia_dataset.select(range(12000))

In [64]:
# Output something random - making sure it works properly
wikipedia_dataset[int(torch.randint(0, len(wikipedia_dataset) - 1, (1,)))]['maintext']

"Epeli Kanakana\n\nRatu Epeli Kanakana (died 2010) was a Fijian chief. He held the title of Tui Suva, and was the traditional ruler of the area that includes the city of Suva, the nation's capital. The title of Tui Suva is only kept within the Naivutuvutu family of the Tokatoka Solia of Mataqali Vuanimocelolo of the Yavusa Vatuwaqa."

### Processing the Hugging Face dataset into a bunch of tensors

The following steps will transform all these english articles into chunks of tokens of length `block_size`.

Doing this upfront will speedup the training process as we won't need to tokenize and move each batch to the CPU as we go through our dataset. 

This will allow us to only work with tensors.

In [65]:
eot_token = tokenizer.eot_token

def tokenize(x):
    tokens = tokenizer.encode(x['maintext'])
    tokens.append(eot_token)
    return { 'tokens':  tokens, 'n_tokens': len(tokens) }

tokenized_dataset = wikipedia_dataset.map(tokenize, num_proc=num_proc, remove_columns=["maintext"])
del wikipedia_dataset
total_tokens = sum(tokenized_dataset['n_tokens'])
print(f"Number of tokens in dataset: {total_tokens / 1e9 :.2f}B")

Map (num_proc=12): 100%|██████████| 4184712/4184712 [03:57<00:00, 17590.84 examples/s]


Number of tokens in dataset: 2.28B


In [72]:
def chunk_dataset(dataset, n_chunks=12):
    n_rows        = len(dataset)
    chunk_size    = math.floor(n_rows / (n_chunks - 1))
    segments = [ dataset.select(range( i*chunk_size, (i+1)*chunk_size )) for i in range(n_chunks - 1) ]

    last_segment = dataset.select(range( (n_chunks - 1) * chunk_size, n_rows - 1 ))
    segments.append(last_segment)

    return segments

In [70]:
def process_chunk(dataset_chunk):
    n_tokens = sum(dataset_chunk['n_tokens'])
    arr = np.empty(n_tokens, dtype=np.uint16)
    ptr = 0
    for v in dataset_chunk:
        tokens = v['tokens']
        arr[ptr:ptr+len(tokens)] = tokens
        ptr += len(tokens)

    return arr

In [73]:
segments = chunk_dataset(tokenized_dataset, n_chunks=10)
all_tokens = np.empty(total_tokens, dtype=np.uint16)

ptr = 0
for i, chunk in enumerate(segments):
    print(f" == Processing chunk {i} ==")
    t0 = time.time()
    tokens = process_chunk(chunk)
    all_tokens[ptr:ptr+len(tokens)] = tokens
    ptr += len(tokens)

    del tokens

    t1 = time.time()
    print(f"==== Chunk {i} took: {t1 - t0 :.2f}s =====")


chunk_size = block_size + 1

n_excess_tokens = len(all_tokens) % chunk_size 
all_tokens = all_tokens[:len(all_tokens) - n_excess_tokens]

chunks_in_dataset = len(all_tokens) // chunk_size
new_shape = (chunks_in_dataset, chunk_size)

all_tokens = all_tokens.reshape(new_shape)

memmap = np.memmap(save_path, dtype=np.uint16, mode='w+', shape=all_tokens.shape)
memmap[:] = all_tokens
del memmap

 == Processing chunk 0 ==
==== Chunk 0 took: 116.30s =====
 == Processing chunk 1 ==
==== Chunk 1 took: 59.70s =====
 == Processing chunk 2 ==
==== Chunk 2 took: 45.88s =====
 == Processing chunk 3 ==
==== Chunk 3 took: 39.12s =====
 == Processing chunk 4 ==
==== Chunk 4 took: 38.96s =====
 == Processing chunk 5 ==
==== Chunk 5 took: 37.12s =====
 == Processing chunk 6 ==
==== Chunk 6 took: 34.74s =====
 == Processing chunk 7 ==
==== Chunk 7 took: 34.62s =====
 == Processing chunk 8 ==
==== Chunk 8 took: 32.06s =====
 == Processing chunk 9 ==
==== Chunk 9 took: 0.00s =====
