In [None]:
# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

# Read the dataset

In [None]:
# Read the whole training text into memory. The context manager guarantees the
# file handle is closed as soon as the read finishes, instead of lingering for
# the lifetime of the kernel.
with open('input.txt', 'r', encoding='utf-8') as corpus_file:
    corpus = corpus_file.read()

# Quick sanity check: total size and the first 200 characters.
print("length of dataset in characters: ", len(corpus))
print(f"\n{corpus[:200]}")

# Build a custom tokenizer with the Hugging Face tokenizers API

In [None]:
import re
from collections import defaultdict

# Count how often each pre-token occurs. The pattern yields, in order of
# preference: a run of word characters, a newline, or a single character that
# is neither whitespace nor a word character.
word_freqs = defaultdict(int)
for word in re.findall(r'\b\w+\b|\n|[^\s\w]', corpus):
    word_freqs[word] += 1

# Initial segmentation of every word into single characters: the first
# character is kept as-is, every following character gets the "##"
# continuation prefix.
splits = {
    word: [c if i == 0 else f"##{c}" for i, c in enumerate(word)]
    for word in word_freqs.keys()}

# The starting alphabet is every distinct piece of the initial segmentation, in
# first-seen order. dict.fromkeys de-duplicates while preserving that order and
# avoids rescanning a growing list for every character.
alphabet = list(dict.fromkeys(
    piece for pieces in splits.values() for piece in pieces))

# Special tokens come first, followed by the alphabet; merged tokens are
# appended to this list during training.
vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"] + alphabet

def compute_pair_scores(splits, freqs=None):
    """Score every adjacent token pair in the current segmentation.

    The score of a pair (a, b) is
    ``pair_freq(a, b) / (token_freq(a) * token_freq(b))``, where all
    frequencies are weighted by how often the containing word occurs.

    Args:
        splits: dict mapping each word to its current list of tokens.
        freqs: optional dict mapping each word to its occurrence count.
            When omitted, the module-level ``word_freqs`` is used, which is
            what existing one-argument callers rely on.

    Returns:
        dict mapping ``(left_token, right_token)`` tuples to a float score.
        Words made of a single token contribute to token frequencies but
        produce no pairs.

    Raises:
        KeyError: if a word in ``freqs`` has no entry in ``splits``.
    """
    if freqs is None:
        freqs = word_freqs

    token_freqs = defaultdict(int)
    pair_freqs = defaultdict(int)
    for word, freq in freqs.items():
        tokens = splits[word]
        # Each token of the word counts once per word occurrence.
        for token in tokens:
            token_freqs[token] += freq
        # Adjacent pairs; a single-token word yields none.
        for pair in zip(tokens, tokens[1:]):
            pair_freqs[pair] += freq

    return {
        pair: freq / (token_freqs[pair[0]] * token_freqs[pair[1]])
        for pair, freq in pair_freqs.items()
    }

def merge_pair(a, b, splits):
    """Merge every adjacent occurrence of tokens ``a`` then ``b`` in ``splits``.

    The merged token is ``a`` followed by ``b`` with its leading ``"##"``
    continuation prefix removed (if it has one).

    Args:
        a: left token of the pair to merge.
        b: right token of the pair to merge.
        splits: dict mapping each word to its current list of tokens. It is
            updated in place.

    Returns:
        The same ``splits`` dict, after merging.
    """
    # The merged token does not depend on the word, so build it once.
    merged = a + b[2:] if b.startswith("##") else a + b

    # Walk the words of `splits` itself rather than a module-level table, so
    # the function only depends on its arguments. Re-assigning values of
    # existing keys while iterating is safe: the dict's size never changes.
    for word, split in splits.items():
        if len(split) == 1:
            continue
        i = 0
        while i < len(split) - 1:
            if split[i] == a and split[i + 1] == b:
                # Replace the two tokens by the merged one; stay at `i` so the
                # new token is compared against its new right neighbour.
                split = split[:i] + [merged] + split[i + 2:]
            else:
                i += 1
        splits[word] = split
    return splits

# Target size of the vocabulary (special tokens + alphabet + learned merges).
vocab_size = 700

# Training loop: repeatedly merge the highest-scoring adjacent pair and add the
# resulting token to the vocabulary until the target size is reached.
while len(vocab) < vocab_size:
    scores = compute_pair_scores(splits)
    if not scores:
        # No word has two tokens left, so there is nothing more to merge;
        # stop instead of trying to unpack an empty "best pair".
        break

    # max() returns the first maximal key, matching a strict ">" scan that
    # keeps the earliest pair on ties.
    best_pair = max(scores, key=scores.get)
    max_score = scores[best_pair]

    splits = merge_pair(*best_pair, splits)
    new_token = ( best_pair[0] + best_pair[1][2:] if best_pair[1].startswith("##") else best_pair[0] + best_pair[1] )
    # Two different merges can produce the same string; only record it once so
    # the vocabulary holds unique tokens.
    if new_token not in vocab:
        vocab.append(new_token)

def encode_word(word):
    """Split one word into vocabulary tokens by greedy longest-prefix matching.

    At each step the longest prefix of the remaining text that is present in
    the module-level ``vocab`` is emitted; whatever is left is then prefixed
    with ``"##"`` and matched again.

    Args:
        word: the pre-tokenized word to encode.

    Returns:
        The list of matched tokens, or ``["[UNK]"]`` for the whole word as soon
        as no prefix of the remaining text is in ``vocab``.
    """
    pieces = []
    remaining = word
    while remaining:
        # Length of the longest known prefix, or 0 when there is none.
        end = next(
            (k for k in range(len(remaining), 0, -1) if remaining[:k] in vocab),
            0,
        )
        if end == 0:
            return ["[UNK]"]
        pieces.append(remaining[:end])
        rest = remaining[end:]
        # Anything after the first piece is a continuation of the word.
        remaining = f"##{rest}" if rest else rest
    return pieces

def tokenize(text):
    """Tokenize ``text`` into a flat list of vocabulary tokens.

    The text is first cut into pre-tokens (runs of word characters, newlines,
    and single characters that are neither whitespace nor word characters);
    each pre-token is then encoded with ``encode_word``.

    Args:
        text: the string to tokenize.

    Returns:
        A single list with the tokens of every pre-token, in order.
    """
    pre_tokenized_text = re.findall(r'\b\w+\b|\n|[^\s\w]', text)
    # Flatten in one pass; sum(list_of_lists, []) re-copies the accumulated
    # list for every word and is quadratic in the number of tokens.
    return [
        token
        for word in pre_tokenized_text
        for token in encode_word(word)
    ]

In [None]:
# worked
# tokenizer = Tokenizer(models.WordPiece(unk_token="[UNK]"))
# tokenizer.normalizer = normalizers.BertNormalizer(lowercase=True)
# tokenizer.normalizer = normalizers.Sequence(
#     [normalizers.NFD(), normalizers.StripAccents()])
# tokenizer.pre_tokenizer = pre_tokenizers.Sequence([pre_tokenizers.Split('', 'isolated'),
#                                                    pre_tokenizers.Punctuation()])

# tokenizer.pre_tokenizer.pre_tokenize_str("Let's test my \n pre-tokenizer.")
# special_tokens = ["[UNK]", "[PAD]", "[CLS]", "[SEP]", "[MASK]"]
# trainer = trainers.WordPieceTrainer(vocab_size=5000, special_tokens=special_tokens)
# tokenizer.train_from_iterator(text, trainer=trainer)

In [None]:
# NOTE(review): `tokenizer` and `text` are not assigned by any active cell
# visible in this notebook (the only `tokenizer` setup in view is commented
# out) — confirm both are defined before this cell on a fresh kernel.
# Encode the 50-character sample once and reuse the result for both views.
sample_encoding = tokenizer.encode(text[:50])
print(sample_encoding.ids)
print(sample_encoding.tokens)

[23, 52, 61, 62, 63, 6, 20, 52, 63, 52, 69, 48, 57, 15, 5, 19, 48, 49, 58, 61, 48, 6, 66, 48, 6, 59, 61, 58, 46, 48, 48, 47, 6, 44, 57, 68, 6, 49, 64, 61, 63, 51, 48, 61, 11, 6, 51, 48, 44, 61]
['F', 'i', 'r', 's', 't', ' ', 'C', 'i', 't', 'i', 'z', 'e', 'n', ':', '\n', 'B', 'e', 'f', 'o', 'r', 'e', ' ', 'w', 'e', ' ', 'p', 'r', 'o', 'c', 'e', 'e', 'd', ' ', 'a', 'n', 'y', ' ', 'f', 'u', 'r', 't', 'h', 'e', 'r', ',', ' ', 'h', 'e', 'a', 'r']
