# WordPiece algorithm (top-down)

There are two versions of the WordPiece algorithm: bottom-up and top-down. In both cases the goal is the same: "Given a training corpus and a number of desired tokens D, the optimization problem is to select D wordpieces such that the resulting corpus is minimal in the number of wordpieces when segmented according to the chosen wordpiece model."
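To make the objective concrete, here is a minimal, self-contained sketch that scores a candidate vocabulary by the number of wordpieces needed to segment a word-count corpus. It assumes greedy longest-match-first segmentation (the rule used later in this notebook when applying WordPiece) and counts an unsegmentable word as a single `<unk>` piece. `greedy_segment` and `corpus_size_in_wordpieces` are illustrative names, not part of any library.

```python
from typing import Dict, List, Set

PREFIX = "##"
UNKNOWN = "<unk>"

def greedy_segment(word: str, vocab: Set[str]) -> List[str]:
    """Greedy longest-match-first segmentation; [<unk>] if some position cannot be matched."""
    pieces: List[str] = []
    start = 0
    while start < len(word):
        # Try the longest remaining substring first, then shrink it from the right.
        for end in range(len(word), start, -1):
            piece = (PREFIX if start else "") + word[start:end]
            if piece in vocab:
                pieces.append(piece)
                start = end
                break
        else:
            return [UNKNOWN]
    return pieces

def corpus_size_in_wordpieces(word_counts: Dict[str, int], vocab: Set[str]) -> int:
    """The quantity to minimise: total number of wordpieces over the whole corpus."""
    return sum(count * len(greedy_segment(word, vocab)) for word, count in word_counts.items())

vocab = {"she", "##ll", "##s", "fish", "##ing"}
print(greedy_segment("shells", vocab))                                # ['she', '##ll', '##s']
print(corpus_size_in_wordpieces({"shells": 2, "fishing": 1}, vocab))  # 2*3 + 1*2 = 8
```

Among vocabularies of the same size D, the one with the smallest value of this objective is preferred.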

The top-down WordPiece generation algorithm takes in a set of (word, count) pairs and a threshold T, and returns a vocabulary V.

The algorithm is iterative. It is run for k iterations, where typically k = 4, but only the first two really matter: the third and fourth (and any beyond) are identical to the second. The threshold T is usually not fixed by hand; instead, a binary search over T is used to obtain a vocabulary of roughly the desired size D. Note that each step of this binary search runs the algorithm from scratch for k iterations.
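The binary search over T mentioned above can be sketched as follows. This is a minimal illustration, not the TensorFlow Text implementation: `search_threshold` is a hypothetical helper, `vocab_size_for(T)` stands for a full from-scratch run of the k iterations at threshold T, and we assume the vocabulary size never grows as T increases.

```python
from typing import Callable

def search_threshold(vocab_size_for: Callable[[int], int], target_size: int,
                     lower: int = 1, upper: int = 10_000) -> int:
    """Binary-search an integer threshold T for a vocabulary of target_size tokens.

    Returns an exact match if one is found, otherwise the evaluated threshold whose
    vocabulary size was closest. Assumes vocab_size_for(T) never increases with T.
    """
    best_t, best_diff = lower, abs(vocab_size_for(lower) - target_size)
    while lower <= upper:
        mid = (lower + upper) // 2
        size = vocab_size_for(mid)  # one full, from-scratch run of the algorithm
        if abs(size - target_size) < best_diff:
            best_t, best_diff = mid, abs(size - target_size)
        if size > target_size:
            lower = mid + 1  # too many tokens: raise the threshold
        elif size < target_size:
            upper = mid - 1  # too few tokens: lower the threshold
        else:
            return mid
    return best_t

# Toy stand-in for a real run: vocabulary size shrinks as the threshold grows.
print(search_threshold(lambda t: 1000 // t, target_size=100))  # 10
```

With the functions defined later in this notebook, `vocab_size_for` could be something like `lambda t: len(get_vocabulary(TREC9_WORD_COUNTS.items(), t))`; since every call repeats the whole computation, this gets expensive on large corpora.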

For this algorithm, we are following the steps as described in the [TensorFlow guide](https://www.tensorflow.org/text/guide/subwords_tokenizer#optional_the_algorithm).

In [21]:
import re
from collections import Counter, defaultdict
from typing import List, Tuple, Dict

import pytest
import ipytest
ipytest.autoconfig()


PREFIX = "##"
UNKNOWN = "<unk>"

In [22]:
text = ("Every morning we look for shells in the sand. I found fifteen big shells"
" last year. I put them in a special place in my room. This year I want to learn"
" to surf. It is hard to surf, but so much fun! My sister is a good surfer. She"
" says that she can teach me. I hope I can do it!")

text = re.sub("[.,!]", "", text)

In [23]:
words = Counter(text.lower().split())
print(words)

Counter({'i': 5, 'in': 3, 'to': 3, 'shells': 2, 'year': 2, 'a': 2, 'my': 2, 'surf': 2, 'it': 2, 'is': 2, 'she': 2, 'can': 2, 'every': 1, 'morning': 1, 'we': 1, 'look': 1, 'for': 1, 'the': 1, 'sand': 1, 'found': 1, 'fifteen': 1, 'big': 1, 'last': 1, 'put': 1, 'them': 1, 'special': 1, 'place': 1, 'room': 1, 'this': 1, 'want': 1, 'learn': 1, 'hard': 1, 'but': 1, 'so': 1, 'much': 1, 'fun': 1, 'sister': 1, 'good': 1, 'surfer': 1, 'says': 1, 'that': 1, 'teach': 1, 'me': 1, 'hope': 1, 'do': 1})


## Iteration 1 steps  

  1. Iterate over every word and count pair in the input, denoted as (w, c).
  2. For each word w, generate every substring, denoted as s. Substrings that do not start at the beginning of the word get the ## prefix. E.g., for the word human, we generate {h, hu, hum, huma, human, ##u, ##um, ##uma, ##uman, ##m, ##ma, ##man, ##a, ##an, ##n}.

In [24]:
def generate_substrings(word:str, word_split_points:List[int]=None) -> List[str]:
    """Generate all word substrings. E.g., for the word human, we generate [h,
    hu, hum, huma, human, ##u, ##um, ##uma, ##uman, ##m, ##ma, ##man, ##a, ##an,
    ##n].

    If word_split_points is set, returns only substrings that start at a split
    point. E.g., for the word human and split points [0, 2], we generate ['h',
    'hu', 'hum', 'huma', 'human', '##m', '##ma', '##man'].

    Args:
        word: Word for which to generate substrings.
        word_split_points: Indices at which substrings may start. If None (or
            empty), substrings starting at every index are generated.

    Returns:
        List of substrings. If a substring does not include the start of the
        word, PREFIX is prepended to it.
    """
    split_points = word_split_points or range(len(word))
    # Slower but more readable approach
    # substrings = []
    # for i in split_points:
    #     for j in range(i + 1, len(word) + 1):
    #         substrings.append(f"##{word[i: j]}" if i > 0 else word[i: j])

    return [(PREFIX if i else "") + word[i: j] for i in split_points for j in range(i + 1, len(word) + 1)]


In [25]:
%%ipytest

def test_generate_substrings_default():
    assert generate_substrings("human") == ["h", "hu", "hum", "huma", "human", "##u", "##um", "##uma", "##uman", "##m", "##ma", "##man", "##a", "##an", "##n"]

def test_generate_substrings():
    assert generate_substrings("human", [0,2]) == ["h", "hu", "hum", "huma", "human", "##m", "##ma", "##man"]

[32m.[0m[32m.[0m[32m                                                                                           [100%][0m
[32m[32m[1m2 passed[0m[32m in 0.01s[0m[0m




  3. Maintain a substring-to-count hash map, and increment the count of each s by c. E.g., if we have (human, 113) and (humas, 3) in our input, the count of s = huma will be 113+3=116.

In [26]:
def get_substring_counts(word_tuples:List[Tuple[str, int]], split_points:List[List[int]]=None) -> Dict[str, int]:
    """Given a list of word-count pairs, returns substrings-count dictionary.
    
    Args:
        word_tuples: List of word-count pairs.
        split_points: List of indices at which to generate substrings for each
            word.

    Returns:
        Dictionary with substring-count pairs
    """
    substrings = defaultdict(int)
    for i, (word, count) in enumerate(word_tuples):
        for substring in generate_substrings(word, split_points[i] if split_points else None):
            substrings[substring] += count
    return substrings


In [27]:
%%ipytest

@pytest.mark.parametrize("key,count", [("human", 113), ("humas", 3), ("huma", 116)])
def test_get_substring_counts(key,count):
    counts = get_substring_counts([("human", 113), ("humas", 3)])
    assert counts[key] == count

[32m.[0m[32m.[0m[32m.[0m[32m                                                                                          [100%][0m
[32m[32m[1m3 passed[0m[32m in 0.01s[0m[0m




  4. Once we've collected the counts of every substring, iterate over the (s, c) pairs starting with the longest s first.
  5. Keep any s that has a c > T. E.g., if T = 100 and we have (pers, 231), (dogs, 259), (##rint, 76), then we would keep pers and dogs.
  6. When an s is kept, subtract off its count from all of its prefixes. This is the reason for sorting all of the s by length in step 4. This is a critical part of the algorithm, because otherwise words would be double counted. For example, let's say that we've kept human and we get to (huma, 116). We know that 113 of those 116 came from human, and 3 came from humas. However, now that human is in our vocabulary, we know we will never segment human into huma ##n. So once human has been kept, then huma only has an effective count of 3.
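The effect of step 6 can be checked on the human/humas example with a toy sketch. For brevity it only tracks word-initial substrings (prefixes of each word, no `##` pieces) and uses the strict c > T rule from step 5; `prefix_counts` and `keep_longest_first` are illustrative names, not functions from this notebook.

```python
from collections import Counter
from typing import Dict, Tuple

def prefix_counts(word_counts: Dict[str, int]) -> Counter:
    """Count every word-initial substring, weighted by the word count (steps 1-3, prefixes only)."""
    counts: Counter = Counter()
    for word, c in word_counts.items():
        for j in range(1, len(word) + 1):
            counts[word[:j]] += c
    return counts

def keep_longest_first(counts: Counter, threshold: int,
                       discount: bool = True) -> Tuple[Dict[str, int], Counter]:
    """Steps 4-6: visit longest strings first, keep those with count > threshold and,
    if discount is True, subtract each kept count from all of its proper prefixes."""
    counts = Counter(counts)  # work on a copy
    kept: Dict[str, int] = {}
    for s in sorted(counts, key=len, reverse=True):
        if counts[s] > threshold:
            kept[s] = counts[s]
            if discount:
                for i in range(1, len(s)):
                    counts[s[:i]] -= kept[s]
    return kept, counts

counts = prefix_counts({"human": 113, "humas": 3})
kept, effective = keep_longest_first(counts, threshold=100)
print(kept, effective["huma"])  # {'human': 113} 3
kept_naive, _ = keep_longest_first(counts, threshold=100, discount=False)
print(sorted(kept_naive))       # ['h', 'hu', 'hum', 'huma', 'human']
```

Without the discount, every prefix of human inherits its 113 counts and is kept as well, which is exactly the double counting step 6 prevents.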

In [28]:
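def get_vocabulary_for_iteration(substrings_counts: Dict[str, int], threshold: int=3) -> Dict[str, int]:
    """Given substring counts, returns a dictionary of the substrings whose
    count is at least `threshold`, mapping each one to a unique index.

    Note: unlike step 5 (c > T), a count equal to the threshold is kept. The
    input dictionary is modified in place: the counts of the prefixes of every
    kept substring are reduced (step 6).

    Args:
        substrings_counts: Dictionary with substring-count pairs. Modified in
            place.
        threshold: Minimum count at which to keep substrings.

    Returns:
        Dictionary of kept substrings as keys and incrementing index as value.
    """
    filtered_substrings = {}
    # Step 4: visit substrings longest first (length measured without PREFIX).
    for substring in sorted(substrings_counts, key=lambda x: len(x.lstrip(PREFIX)), reverse=True):
        # Step 5: skip substrings below the threshold.
        if substrings_counts[substring] < threshold:
            continue

        filtered_substrings[substring] = len(filtered_substrings)
        # Step 6: subtract the kept count from every proper prefix. For ##-pieces
        # start after the marker so that "#" and "##" are never touched.
        for i in range(len(PREFIX) + 1 if substring.startswith(PREFIX) else 1, len(substring)):
            substrings_counts[substring[:i]] -= substrings_counts[substring]
    return filtered_substrings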

In [29]:
%%ipytest

@pytest.fixture
def substrings():
    return get_substring_counts(words.items())

def test_get_vocabulary_for_iteration_len(substrings):
    vocabulary = get_vocabulary_for_iteration(substrings, 3)
    assert len(vocabulary) == 33

def test_get_vocabulary_for_iteration_3(substrings):
    vocabulary = get_vocabulary_for_iteration(substrings, 3)
    assert vocabulary["she"] == 1

def test_get_vocabulary_for_iteration_2(substrings):
    vocabulary = get_vocabulary_for_iteration(substrings, 2)
    assert vocabulary["year"] == 3

[32m.[0m[32m.[0m[32m.[0m[32m                                                                                          [100%][0m
[32m[32m[1m3 passed[0m[32m in 0.01s[0m[0m




## Applying WordPiece

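Once a WordPiece vocabulary has been generated, we need to be able to apply it to new data. The algorithm is a simple greedy longest-match-first application: starting from the whole word, we look up the longest prefix that is in the vocabulary, emit it as a token, and repeat on the remainder of the word (now marked with ##) until nothing is left. If at some point no prefix matches, the whole word is mapped to <unk>.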

In [30]:
def tokenize(word:str, vocabulary: Dict[str, int]) -> List[str]:
    """Tokenize a single word with the given vocabulary. Returns [<unk>] if the
    word cannot be fully tokenized due to missing tokens in the vocabulary.

    Args:
        word: Word to tokenize.
        vocabulary: Dictionary where keys are tokens and values are unique
            numbers associated with them.

    Returns:
        List of tokens.
    """
    word = word.lower()
    tokens = []
    # Loop until only the PREFIX (or nothing) remains of the word.
    while word.lstrip(PREFIX):
        # Try the longest candidate first and shorten it one character at a time.
        for i in range(len(word)):
            if word[:len(word)-i] in vocabulary:
                tokens.append(word[:len(word)-i])
                # Continue with the rest of the word, marked as a continuation.
                word = PREFIX + word[len(word)-i:]
                break
        else:
            # No prefix of the remaining word is in the vocabulary.
            return [UNKNOWN]
    return tokens



In [31]:
%%ipytest

@pytest.fixture
def vocabulary(substrings):
    return get_vocabulary_for_iteration(substrings, 3)

@pytest.mark.parametrize("word,tokens", [("shells", ["she", "##l", "##l", "##s"]), ("fishing", ["<unk>"])])
def test_tokenize(word, tokens, vocabulary):
    assert tokenize(word, vocabulary) == tokens



[32m.[0m[32m.[0m[32m                                                                                           [100%][0m
[32m[32m[1m2 passed[0m[32m in 0.01s[0m[0m


## Iteration 2+

This algorithm severely overgenerates wordpieces. The reason is that we only subtract off the counts of prefix tokens. Therefore, if we keep the word human, we subtract off the counts for h, hu, hum and huma, but not for ##u, ##um, ##uma, ##uman and so on. So we might generate both human and ##uman as wordpieces, even though ##uman will never be applied.
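A minimal sketch of iteration 1 reproduces the overgeneration on a one-word corpus. It assumes a hypothetical corpus in which `human` occurs 200 times and T = 100, and uses the strict c > T rule; `first_iteration`, `all_substrings` and `unprefixed_len` are illustrative names, not functions from this notebook.

```python
from collections import defaultdict
from typing import Dict, List

PREFIX = "##"

def all_substrings(word: str) -> List[str]:
    """Every substring of word; those not starting at index 0 carry the ## prefix."""
    return [(PREFIX if i else "") + word[i:j]
            for i in range(len(word)) for j in range(i + 1, len(word) + 1)]

def unprefixed_len(piece: str) -> int:
    return len(piece) - len(PREFIX) if piece.startswith(PREFIX) else len(piece)

def first_iteration(word_counts: Dict[str, int], threshold: int) -> List[str]:
    """Iteration 1: count all substrings, then keep longest-first with prefix-only discounting."""
    counts: Dict[str, int] = defaultdict(int)
    for word, c in word_counts.items():
        for s in all_substrings(word):
            counts[s] += c
    kept: List[str] = []
    for s in sorted(counts, key=unprefixed_len, reverse=True):
        if counts[s] > threshold:
            kept.append(s)
            # Only proper prefixes are discounted: human lowers h..huma, but not ##uman.
            start = len(PREFIX) + 1 if s.startswith(PREFIX) else 1
            for i in range(start, len(s)):
                counts[s[:i]] -= counts[s]
    return kept

print(first_iteration({"human": 200}, threshold=100))
# ['human', '##uman', '##man', '##an', '##n']
```

Since human itself is kept, greedy tokenization never reaches ##uman, ##man, ##an or ##n for this corpus; they are exactly the kind of dead pieces the overgeneration produces.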

So why not subtract off the counts for every substring, not just every prefix? Because then we could end up subtracting off the counts multiple times. Let's say that we're processing s of length 5 and we keep both (##denia, 129) and (##eniab, 137), where 65 of those counts came from the word undeniable. If we subtract off from every substring, we would subtract 65 from the substring ##enia twice, even though we should only subtract once. However, if we only subtract off from prefixes, it will correctly only be subtracted once.
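The double-counting argument can be checked with a few lines of arithmetic. This sketch only tracks the 65 counts contributed by undeniable, assumes both kept pieces are word-internal (`##`) pieces, and uses illustrative helper names that are not part of the notebook.

```python
from typing import List

PREFIX = "##"

def proper_prefixes(piece: str) -> List[str]:
    """Shorter prefixes of a piece, never cutting into the ## marker (##denia -> ##d ... ##deni)."""
    start = len(PREFIX) + 1 if piece.startswith(PREFIX) else 1
    return [piece[:i] for i in range(start, len(piece))]

def proper_substrings(piece: str) -> List[str]:
    """All shorter substrings of a word-internal (##) piece, each written with ##."""
    body = piece[len(PREFIX):]
    return [PREFIX + body[i:j]
            for i in range(len(body)) for j in range(i + 1, len(body) + 1)
            if j - i < len(body)]

def amount_subtracted(target: str, kept: List[str], count: int, every_substring: bool) -> int:
    """How much of `count` is removed from `target` once every piece in `kept` is discounted."""
    total = 0
    for piece in kept:
        affected = proper_substrings(piece) if every_substring else proper_prefixes(piece)
        if target in affected:
            total += count
    return total

kept = ["##denia", "##eniab"]
print(amount_subtracted("##enia", kept, 65, every_substring=True))   # 130: subtracted twice
print(amount_subtracted("##enia", kept, 65, every_substring=False))  # 65: subtracted once
```

##enia is an inner substring of ##denia but a prefix only of ##eniab, so prefix-only discounting removes the 65 counts exactly once.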

To solve the overgeneration issue mentioned above, we perform multiple iterations of the algorithm.

 * Subsequent iterations are identical to the first, with one important distinction: In step 2, instead of considering every substring, we apply the WordPiece tokenization algorithm using the vocabulary from the previous iteration, and only consider substrings which start on a split point.
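A small worked example of this second-iteration step, as a self-contained sketch: assuming a hypothetical previous vocabulary containing un, ##deni, ##able, ##ndeni and ##iable, greedy longest-match-first segmentation turns undeniable into un ##deni ##able, so only substrings starting at indices 0, 2 and 6 are generated (every end position is still allowed). The helper names are illustrative.

```python
from typing import List, Set

PREFIX = "##"
UNKNOWN = "<unk>"

def greedy_segment(word: str, vocab: Set[str]) -> List[str]:
    """Greedy longest-match-first segmentation; [<unk>] if some position cannot be matched."""
    pieces: List[str] = []
    start = 0
    while start < len(word):
        for end in range(len(word), start, -1):
            piece = (PREFIX if start else "") + word[start:end]
            if piece in vocab:
                pieces.append(piece)
                start = end
                break
        else:
            return [UNKNOWN]
    return pieces

def piece_starts(pieces: List[str]) -> List[int]:
    """Start index of each piece within the word (the split points)."""
    points, offset = [], 0
    for piece in pieces:
        points.append(offset)
        offset += len(piece) - len(PREFIX) if piece.startswith(PREFIX) else len(piece)
    return points

def substrings_at(word: str, points: List[int]) -> List[str]:
    """Substrings that start at a split point; any end position is allowed."""
    return [(PREFIX if i else "") + word[i:j] for i in points for j in range(i + 1, len(word) + 1)]

previous_vocab = {"un", "##deni", "##able", "##ndeni", "##iable"}
pieces = greedy_segment("undeniable", previous_vocab)  # ['un', '##deni', '##able']
points = piece_starts(pieces)                          # [0, 2, 6]
candidates = substrings_at("undeniable", points)
print(len(candidates))                                 # 22, versus 55 in iteration 1
```

Pieces such as ##ndeni and ##iable no longer receive counts from undeniable, which is how the later iterations prune overgenerated pieces.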

In [32]:
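def get_split_points(word: str, vocabulary: Dict[str, int]) -> List[int]:
    """Returns the split points of a word under the given vocabulary.

    Args:
        word: Word for which to obtain tokens and infer split points.
        vocabulary: Dictionary with available tokens.

    Returns:
        Start indices of the tokens the word is split into (the first one is
        always 0). If the word cannot be tokenized, every index of the word is
        returned, so all of its substrings are considered, as in iteration 1.
    """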
    tokens = tokenize(word, vocabulary)
    if tokens[0] == UNKNOWN:
        return list(range(len(word)))

    split_points = []
    cumulative_sum = 0
    for token in tokens:
        split_points.append(cumulative_sum)
        cumulative_sum += len(token.lstrip(PREFIX))
    return split_points

In [33]:
%%ipytest

@pytest.mark.parametrize("word,splits", [("shells", [0, 3, 4, 5]), ("fishing", [0,1,2,3,4,5,6])])
def test_get_split_points(word, splits, vocabulary):
    assert get_split_points(word, vocabulary) == splits




[32m.[0m[32m.[0m[32m                                                                                           [100%][0m
[32m[32m[1m2 passed[0m[32m in 0.01s[0m[0m


## Full implementation
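For the full implementation we run k = 4 iterations: the first considers every substring, and each subsequent iteration only considers substrings that start at the split points produced by the previous iteration's vocabulary.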

In [34]:
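def get_vocabulary(word_tuples: List[Tuple[str, int]], threshold: int, iterations: int = 4) -> Dict[str, int]:
    """Runs the top-down WordPiece algorithm for a fixed threshold.

    Args:
        word_tuples: Iterable (e.g., list, generator) of word-count pairs.
        threshold: Threshold at which to keep substrings.
        iterations: Number of iterations (k). The first one considers every
            substring, later ones only substrings starting at split points.

    Returns:
        Dictionary of kept substrings as keys and incrementing index as value.
    """
    # Materialize the input: it is traversed several times per iteration, so a
    # generator would otherwise be exhausted after the first pass.
    word_tuples = list(word_tuples)
    vocabulary = {}
    for i in range(iterations):
        # Iteration 1 uses every substring; later iterations use the split points
        # given by the vocabulary of the previous iteration.
        split_points = [get_split_points(word, vocabulary) for word, _ in word_tuples] if i else None
        substrings_counts = get_substring_counts(word_tuples, split_points=split_points)
        vocabulary = get_vocabulary_for_iteration(substrings_counts, threshold=threshold)
    return vocabulary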


In [35]:
%%ipytest

def test_get_vocabulary():
    vocabulary = get_vocabulary(words.items(), 2)
    assert tokenize("year", vocabulary) == ["year"]



[32m.[0m[32m                                                                                            [100%][0m
[32m[32m[1m1 passed[0m[32m in 0.01s[0m[0m


In [36]:
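# Build word counts from the title and body term vectors of every document in a
# local Elasticsearch index (assumes a running cluster with the trec9_index index).
from elasticsearch import Elasticsearch
from collections import Counter

es = Elasticsearch(timeout=120)
# Allow a single search request to return up to 100,000 hits (applied to all indices).
es.indices.put_settings(body={"max_result_window": 100000})
index = "trec9_index"

ids = [doc["_id"] for doc in es.search(index=index, query={"match_all": {}}, _source=False, size=100000)["hits"]["hits"]]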

TREC9_WORD_COUNTS = Counter()
for doc_id in ids:
    for field in ["title", "body"]:
        terms = es.termvectors(index=index, id=doc_id, fields=field, term_statistics=True).get("term_vectors", {}).get(field, {}).get("terms", {})
        TREC9_WORD_COUNTS.update({key: value["term_freq"] for key, value in terms.items()})

print(f"Number of words in corpus: {len(TREC9_WORD_COUNTS)}")




Number of words in corpus: 77650


In [42]:
for threshold in [5, 100, 1000]:
    vocabulary = get_vocabulary(TREC9_WORD_COUNTS.items(), threshold)
    print(f"\nLength of vocabulary: {len(vocabulary)}\nTokens:")
    for word in ["shells", "fishing"]:
        print(f"\t{word}: {tokenize(word, vocabulary)}")


Length of vocabulary: 43715
Tokens:
	shells: ['shells']
	fishing: ['fish', '##ing']

Length of vocabulary: 10943
Tokens:
	shells: ['she', '##ll', '##s']
	fishing: ['fish', '##ing']

Length of vocabulary: 2950
Tokens:
	shells: ['she', '##ll', '##s']
	fishing: ['fis', '##hin', '##g']


In [38]:
tokenize("undeniable", vocabulary)

['un', '##den', '##ia', '##ble']

In [39]:
%%ipytest

@pytest.fixture
def vocabulary():
    return get_vocabulary(TREC9_WORD_COUNTS.items(), 100)

@pytest.mark.parametrize("word,tokens", [("shells", ["she", "##ll", "##s"]), ("fishing", ["fish", "##ing"])])
def test_tokenize(word, tokens, vocabulary):
    assert tokenize(word, vocabulary) == tokens



[32m.[0m[32m.[0m[32m                                                                                           [100%][0m
[32m[32m[1m2 passed[0m[32m in 9.65s[0m[0m
