In [None]:
!pip install datasets tiktoken

In [2]:
import os
import multiprocessing as mp
import numpy as np
from datasets import load_dataset
from tqdm import tqdm
import json
import tiktoken

In [3]:
import regex as re


class TrieNode:
    def __init__(self):
        self.id = None
        self.children = {}

class Trie:
    def __init__(self):
        self.root = TrieNode()

    def insert(self, token, id):
        node = self.root
        for char in token:
            if char not in node.children:
                node.children[char] = TrieNode()
            node = node.children[char]
        node.id = id

    def search(self, text, start_pos):
        match_id = None
        pos = start_pos
        token_len = 0
        node = self.root
        while pos < len(text):
            char = text[pos]
            if char not in node.children:
                break
            node = node.children[char]
            if node.id is not None:
                match_id = node.id
                token_len = (pos - start_pos) + 1
            pos += 1
        return match_id, token_len

    def encode(self, text):
        pos = 0
        ids = []
        while pos < len(text):
            id, token_length = self.search(text, pos)
            ids.append(id)
            pos += token_length
        return ids


def bytes_to_unicode():
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


class LinearTokenizer():
    def __init__(self, vocab):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}

        self.vocab_encode = vocab
        self.vocab_decode = {v:k for k,v in vocab.items()}

        self.trie = Trie()
        for token, token_id in self.vocab_encode.items():
            self.trie.insert(token, token_id)

        # https://github.com/openai/tiktoken/blob/main/tiktoken_ext/openai_public.py
        self.pattern = re.compile(r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}++| ?\p{N}++| ?[^\s\p{L}\p{N}]++|\s++$|\s+(?!\S)|\s""")


    def encode(self, text, return_token_tuple=False):
        # TODO special handling of <|endoftext|> token
        pretokens = self.pattern.findall(text)
        pretokens = [''.join(self.byte_encoder[b] for b in pretoken.encode('utf-8')) for pretoken in pretokens]
        ids = []
        for pretoken in pretokens:
            if pretoken in self.vocab_encode:
                ids.append(self.vocab_encode[pretoken])
            else:
                ids.extend(self.trie.encode(pretoken))

        if return_token_tuple:
            return (ids, [self.vocab_decode[id] for id in ids])
        return ids

    def decode(self, ids):
        out = ""
        for id in ids:
            if not id in self.vocab_decode:
                raise Exception(f"Error decoding {id}")
            out += self.vocab_decode[id]
        return bytearray([self.byte_decoder[c] for c in out]).decode('utf-8', errors="replace")


In [None]:
class TTWrapper:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.eot = tokenizer._special_tokens['<|endoftext|>']

    def encode(self, text):
        return self.tokenizer.encode_ordinary(text)

    def decode(self, ids):
        return self.tokenizer.decode(ids)

tokenizer = TTWrapper(tiktoken.get_encoding("gpt2"))
eot = tokenizer.eot

In [None]:
with open("/content/drive/MyDrive/Colab Notebooks/linear/vocab.json", 'r', encoding='utf-8') as f:
    vocab = json.load(f)
tokenizer = LinearTokenizer(vocab)
eot = tokenizer.vocab_encode["<|endoftext|>"]

In [None]:
text = "The quick brown Fox jumps 1234 OVER the lazy Dog."
ids = tokenizer.encode(text)
decoded = tokenizer.decode(ids)

print(text)
print(ids)
print(decoded)
print(decoded == text)

The quick brown Fox jumps 1234 OVER the lazy Dog.
[464, 2068, 7586, 5426, 18045, 17031, 19, 28729, 262, 16931, 8532, 13]
The quick brown Fox jumps 1234 OVER the lazy Dog.
True


In [4]:
local_dir = "edu_fineweb10B"
remote_path = "HuggingFaceFW/fineweb-edu"
remote_name = "sample-10BT" # None
shard_size = int(1e8) # 100M tokens per shard, total of 100 shards

dataset = load_dataset(remote_path, name=remote_name, split="train")
print(len(dataset))

README.md:   0%|          | 0.00/23.3k [00:00<?, ?B/s]

Resolving data files:   0%|          | 0/1630 [00:00<?, ?it/s]

000_00000.parquet:   0%|          | 0.00/2.15G [00:00<?, ?B/s]

001_00000.parquet:   0%|          | 0.00/2.15G [00:00<?, ?B/s]

002_00000.parquet:   0%|          | 0.00/2.15G [00:00<?, ?B/s]

003_00000.parquet:   0%|          | 0.00/2.15G [00:00<?, ?B/s]

004_00000.parquet:   0%|          | 0.00/2.15G [00:00<?, ?B/s]

005_00000.parquet:   0%|          | 0.00/2.15G [00:00<?, ?B/s]

006_00000.parquet:   0%|          | 0.00/2.15G [00:00<?, ?B/s]

007_00000.parquet:   0%|          | 0.00/2.15G [00:00<?, ?B/s]

008_00000.parquet:   0%|          | 0.00/2.15G [00:00<?, ?B/s]

009_00000.parquet:   0%|          | 0.00/2.15G [00:00<?, ?B/s]

010_00000.parquet:   0%|          | 0.00/2.15G [00:00<?, ?B/s]

011_00000.parquet:   0%|          | 0.00/2.15G [00:00<?, ?B/s]

012_00000.parquet:   0%|          | 0.00/2.15G [00:00<?, ?B/s]

013_00000.parquet:   0%|          | 0.00/541M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/9672101 [00:00<?, ? examples/s]

Loading dataset shards:   0%|          | 0/98 [00:00<?, ?it/s]

9672101


In [None]:
# only use part of dataset for experimental runs
dataset = dataset.select(range(len(dataset) // 10))

In [None]:
# From https://github.com/karpathy/build-nanogpt/blob/master/fineweb.py

DATA_CACHE_DIR = "/content/drive/MyDrive/Colab Notebooks/linear/content/data/"
os.makedirs(DATA_CACHE_DIR, exist_ok=True)

def tokenize(doc):
    # tokenizes a single document and returns a numpy array of uint16 tokens
    tokens = [eot] # the special <|endoftext|> token delimits all documents
    tokens.extend(tokenizer.encode(doc["text"]))
    tokens_np = np.array(tokens)
    # print(tokens_np.dtype)
    assert (0 <= tokens_np).all() and (tokens_np < 2**16).all(), "token dictionary too large for uint16"
    tokens_np_uint16 = tokens_np.astype(np.uint16)
    return tokens_np_uint16

def write_datafile(filename, tokens_np):
    np.save(filename, tokens_np)

total_token_count = 0
# tokenize all documents and write output shards, each of shard_size tokens (last shard has remainder)
nprocs = max(1, os.cpu_count()//2)
with mp.Pool(nprocs) as pool:
    shard_index = 0
    # preallocate buffer to hold current shard
    all_tokens_np = np.empty((shard_size,), dtype=np.uint16)
    token_count = 0
    progress_bar = None
    for tokens in pool.imap(tokenize, dataset, chunksize=16):
        # is there enough space in the current shard for the new tokens?
        if token_count + len(tokens) < shard_size:
            # simply append tokens to current shard
            all_tokens_np[token_count:token_count+len(tokens)] = tokens
            token_count += len(tokens)
            total_token_count += len(tokens)
            # update progress bar
            if progress_bar is None:
                progress_bar = tqdm(total=shard_size, unit="tokens", desc=f"Shard {shard_index}", mininterval=10)
            progress_bar.update(len(tokens))
        else:
            # write the current shard and start a new one
            split = "val" if shard_index == 0 else "train"
            filename = os.path.join(DATA_CACHE_DIR, f"edufineweb_{split}_{shard_index:06d}")
            # split the document into whatever fits in this shard; the remainder goes to next one
            remainder = shard_size - token_count
            progress_bar.update(remainder)
            all_tokens_np[token_count:token_count+remainder] = tokens[:remainder]
            write_datafile(filename, all_tokens_np)
            shard_index += 1
            progress_bar = None
            # populate the next shard with the leftovers of the current doc
            all_tokens_np[0:len(tokens)-remainder] = tokens[remainder:]
            token_count = len(tokens)-remainder

    # write any remaining tokens as the last shard
    if token_count != 0:
        split = "val" if shard_index == 0 else "train"
        filename = os.path.join(DATA_CACHE_DIR, f"edufineweb_{split}_{shard_index:06d}")
        write_datafile(filename, all_tokens_np[:token_count])

print(total_token_count)

Shard 0:  81%|████████  | 80594992/100000000 [00:54<00:12, 1582882.99tokens/s]
Shard 0: 100%|██████████| 100000000/100000000 [01:05<00:00, 1533506.28tokens/s]

Shard 1:  22%|██▏       | 21719145/100000000 [00:10<00:36, 2163320.49tokens/s][A
Shard 1:  43%|████▎     | 43354338/100000000 [00:21<00:27, 2023499.20tokens/s][A
Shard 1:  64%|██████▍   | 64245487/100000000 [00:31<00:17, 2048524.29tokens/s][A
Shard 1: 100%|█████████▉| 99999680/100000000 [00:50<00:00, 1998826.24tokens/s]
Shard 2:  84%|████████▍ | 84124973/100000000 [00:41<00:07, 2018182.47tokens/s]
Shard 2: 100%|█████████▉| 99998827/100000000 [00:50<00:00, 1988172.78tokens/s]

Shard 3:  20%|██        | 20070951/100000000 [00:10<00:39, 2003809.45tokens/s][A
Shard 3:  40%|████      | 40109559/100000000 [00:20<00:30, 1994995.11tokens/s][A
Shard 3:  61%|██████    | 60733920/100000000 [00:30<00:19, 2022994.43tokens/s][A
Shard 3: 100%|█████████▉| 99998900/100000000 [00:50<00:00, 1979396.72tokens/s]
Shard 4:  83%|████████▎ | 82583

9933071906


\

In [None]:
%cd "/content/drive/MyDrive/Colab Notebooks/linear"
!pwd
!zip -r ./dataset-lin.zip "./content/data/"

In [15]:
pretoken_count = 9034884870 # splitting dataset by GPT2 regex, minimum possible amount of tokens
bpe_token_count = 9953457919 # 48min
lin_token_count = 9933071906 # 43min
lin_token_count / bpe_token_count

0.9979518662593544

In [29]:
print(f"Fraction lin/bpe token count = {lin_token_count / bpe_token_count :0.4f}")
print(f"Fraction of tokens over optimal encoding = {lin_token_count / pretoken_count :.4f}(lin), {bpe_token_count / pretoken_count :.4f}(bpe)")

Fraction lin/bpe token count = 0.9980
Fraction of tokens over optimal encoding = 1.0994(lin), 1.1017(bpe)
