# WordPiece Tokenizer

## 🔹 What is WordPiece Tokenizer?

**WordPiece** is a **subword-based tokenizer** developed at Google and best known as the tokenizer used by **BERT**.
It aims to:

* Handle unknown words (OOV = out-of-vocabulary),
* Reduce vocabulary size,
* Preserve semantic meaning as much as possible.

---

## 🔹 Why WordPiece?

Traditional tokenizers (whitespace- or word-based) struggle in two ways:

* Rare or unseen words all collapse to `[UNK]`.
* Covering every possible word requires a huge vocabulary.

✅ **WordPiece** solves this by:

* Breaking words into **known subword units**.
* Leveraging a **data-driven** method to learn which subwords to merge.

---

## 🔹 How it works – Step by Step

### 1. **Initial Vocabulary**

* Start with a character-level vocabulary (`a` to `z`, digits, symbols, and special tokens like `[CLS]`, `[SEP]`, `[UNK]`, etc.)
* Each word in the training corpus is represented as characters:

  ```
  "hello" → ['h', 'e', 'l', 'l', 'o']
         → ['h', '##e', '##l', '##l', '##o']   # WordPiece adds '##' to mark continuation
  ```
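The initial split can be sketched in a few lines. This is a minimal illustration, assuming a toy two-line corpus and a hypothetical helper name `split_into_pieces`; neither comes from BERT's code.

```python
from collections import Counter


def split_into_pieces(word):
    """Split a word into characters; every non-initial character gets a '##' prefix."""
    return [word[0]] + ["##" + ch for ch in word[1:]]


# Toy corpus, assumed purely for illustration.
corpus = ["hello help", "hello"]

# Count whitespace-separated words, then collect every distinct piece.
word_freq = Counter(word for line in corpus for word in line.split())
initial_vocab = sorted({piece for word in word_freq for piece in split_into_pieces(word)})

print(split_into_pieces("hello"))  # ['h', '##e', '##l', '##l', '##o']
print(initial_vocab)               # ['##e', '##l', '##o', '##p', 'h']
```

Note that `l` at the start of a word and `##l` inside a word would be two different vocabulary entries.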

---

### 2. **Build Vocabulary using Merge Rules**

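WordPiece builds its vocabulary much like Byte-Pair Encoding (BPE), but instead of merging the most frequent pair it merges the pair with the highest **score**:

#### 🔹 Merge Rule Scoring Formula:

$$
\text{Score}(A, B) = \frac{\text{freq}(AB) \times |\text{Vocab}|}{\text{freq}(A) \times \text{freq}(B)}
$$

Where:

* `freq(AB)` = how often `A` is immediately followed by `B` in the corpus,
* `|Vocab|` = current vocabulary size,
* `freq(A)` and `freq(B)` = individual frequencies of `A` and `B`.

Dividing by the individual frequencies favors pairs whose parts rarely appear apart over pairs that are merely frequent. `|Vocab|` is identical for every pair within a step, so it rescales the scores without changing which pair wins.

Merge the best-scoring pair, add the merged token to the vocabulary, and repeat until the vocabulary reaches the desired size.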
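To see how the score differs from plain pair frequency, here is a small worked example. The word counts are hypothetical and the variable names are chosen only for this sketch.

```python
from collections import Counter

# Hypothetical word counts; each word is already split into pieces.
word_pieces = {
    "hug":  (["h", "##u", "##g"], 10),
    "pug":  (["p", "##u", "##g"], 5),
    "pun":  (["p", "##u", "##n"], 12),
    "bun":  (["b", "##u", "##n"], 4),
    "hugs": (["h", "##u", "##g", "##s"], 5),
}

piece_freq = Counter()  # freq(A)
pair_freq = Counter()   # freq(AB)
for pieces, count in word_pieces.values():
    for piece in pieces:
        piece_freq[piece] += count
    for left, right in zip(pieces, pieces[1:]):
        pair_freq[(left, right)] += count

vocab_size = len(piece_freq)  # |Vocab| for this step


def score(pair):
    left, right = pair
    return pair_freq[pair] * vocab_size / (piece_freq[left] * piece_freq[right])


best_by_score = max(pair_freq, key=score)
most_frequent = max(pair_freq, key=pair_freq.get)

print(best_by_score)  # ('##g', '##s')
print(most_frequent)  # ('##u', '##g')
```

The most frequent pair is `('##u', '##g')`, but `##u` appears in every word, so its score is diluted. `##s` only ever follows `##g`, which gives `('##g', '##s')` the highest score.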

---

### 3. **Tokenization of New Sentences**

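To tokenize any input string:

* Split the text into words and process each word from left to right.
* At each position, take the **longest** vocabulary entry that matches (greedy longest-match-first). Every piece after the first carries the `##` prefix.
* If no piece matches at some position, the word cannot be segmented. BERT's implementation then maps the whole word to `[UNK]`.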

Example:

* Vocabulary: `["[CLS]", "[SEP]", "my", "name", "is", "saga", "##r"]`
* Input: `"my name is sagar"`

Tokenized:

```python
["my", "name", "is", "saga", "##r"]
```
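The greedy lookup can be sketched as follows. This is a minimal illustration, not BERT's actual code: the function name is hypothetical, `[UNK]` is added to the example vocabulary as an assumption, and a word that cannot be fully segmented is mapped to a single `[UNK]`.

```python
def wordpiece_tokenize(word, vocab, unk_token="[UNK]"):
    """Greedy longest-match-first segmentation of a single word."""
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        # Shrink the candidate from the right until it is found in the vocabulary.
        while end > start:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation pieces carry '##'
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            return [unk_token]  # nothing matches here: give up on the whole word
        pieces.append(match)
        start = end
    return pieces


vocab = {"[CLS]", "[SEP]", "[UNK]", "my", "name", "is", "saga", "##r"}

sentence = "my name is sagar"
tokens = [piece for word in sentence.split() for piece in wordpiece_tokenize(word, vocab)]
print(tokens)                              # ['my', 'name', 'is', 'saga', '##r']
print(wordpiece_tokenize("sagas", vocab))  # ['[UNK]']
```

`"sagas"` matches `saga` first, but `##s` is not in the vocabulary, so the whole word falls back to `[UNK]`.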

---

### 4. **Special Tokens**

* `[CLS]` → added at start
* `[SEP]` → added at end or between sentence pairs
* `[PAD]` → for padding
* `[UNK]` → unknown tokens

So the final tokenized form becomes:

```
["[CLS]", "my", "name", "is", "saga", "##r", "[SEP]"]
```
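A small helper makes the placement of these tokens concrete. The function name `build_input` and its parameters are assumptions made for this sketch and are not part of any library.

```python
def build_input(tokens_a, tokens_b=None, max_len=None):
    """Wrap one or two token lists in special tokens and optionally pad."""
    tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
    if tokens_b is not None:
        # Sentence pairs are separated, and closed, by [SEP].
        tokens += tokens_b + ["[SEP]"]
    if max_len is not None:
        # A negative count yields an empty list, so long inputs are left unpadded.
        tokens += ["[PAD]"] * (max_len - len(tokens))
    return tokens


print(build_input(["my", "name", "is", "saga", "##r"]))
# ['[CLS]', 'my', 'name', 'is', 'saga', '##r', '[SEP]']
print(build_input(["hi"], ["there"]))
# ['[CLS]', 'hi', '[SEP]', 'there', '[SEP]']
print(build_input(["hi"], max_len=5))
# ['[CLS]', 'hi', '[SEP]', '[PAD]', '[PAD]']
```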

---

## 🔹 Example from BERT Tokenizer

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
print(tokenizer.tokenize("unhappiness"))

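# Output (illustrative; the actual split depends on the checkpoint's
# vocabulary and may use shorter pieces than shown here):
# ['un', '##happiness']
```

For this split to occur, the vocabulary must contain:

* `un` as a word-initial piece,
* `##happiness` as a continuation piece.

If a long continuation such as `##happiness` is missing from the vocabulary, the tokenizer falls back to shorter `##` pieces, so running this snippet may print more than two tokens.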

---

## 🔹 Pros and Cons

### ✅ Pros:

* Handles unknown words gracefully.
* Compact vocab (e.g., BERT base uses \~30,000 tokens).
* Learns meaningful subwords (e.g., "play", "##ing", "##er").

### ❌ Cons:

* The vocabulary must be learned from a large corpus before the tokenizer can be used.
* Can break up rare or morphologically rich words awkwardly.

---

## 🔹 Summary Table

| Step                   | Action                             |
| ---------------------- | ---------------------------------- |
| Init vocab             | Character-level                    |
| Build word frequencies | From corpus                        |
| Score token pairs      | Using scoring function             |
| Merge best pairs       | Until vocab size is reached        |
| Tokenize new sentence  | Using longest-match subword lookup |
| Add special tokens     | `[CLS]`, `[SEP]`, etc.             |

---


## Custom Implementation of a WordPiece Tokenizer

```python
import json
from collections import defaultdict, Counter


class CustomWordPieceTokenizer:
    def __init__(self, vocab_size=100, special_tokens=None, max_len=16):
        self.vocab_size = vocab_size   # target vocabulary size for training
        self.vocab = {}                # token -> id
        self.vocab_inv = {}            # id -> token (built lazily)
        self.word_freq = Counter()     # word -> count in the training corpus
        self.word_tokens = {}          # word -> current list of pieces
        self.merge_rules = []          # learned merges, in order
        self.max_len = max_len         # sequence length for padding/truncation

        # Special tokens always occupy the first ids.
        self.special_tokens = special_tokens or ['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]']
        for i, token in enumerate(self.special_tokens):
            self.vocab[token] = i
        self.offset = len(self.special_tokens)

    def train(self, corpus):
        # Count whitespace-separated words.
        for sentence in corpus:
            for word in sentence.strip().split():
                self.word_freq[word] += 1

        # Split every word into characters; non-initial characters get '##'.
        vocab_set = set()
        for word in self.word_freq:
            chars = list(word)
            tokens = [chars[0]] + ['##' + c for c in chars[1:]]
            self.word_tokens[word] = tokens
            vocab_set.update(tokens)

        self.vocab.update({tok: i + self.offset for i, tok in enumerate(sorted(vocab_set))})

        # Merge the best-scoring pair until the target size is reached.
        while len(self.vocab) < self.vocab_size:
            pair_freq = defaultdict(int)
            token_freq = defaultdict(int)

            for word, tokens in self.word_tokens.items():
                freq = self.word_freq[word]
                for token in tokens:
                    token_freq[token] += freq
                for i in range(len(tokens) - 1):
                    pair = (tokens[i], tokens[i + 1])
                    pair_freq[pair] += freq

            if not pair_freq:
                break  # every word is a single token; nothing left to merge

            def score(pair):
                ab = pair_freq[pair]
                a = token_freq[pair[0]]
                b = token_freq[pair[1]]
                return (ab * len(self.vocab)) / (a * b + 1e-9)

            best_pair = max(pair_freq, key=score)
            # The second piece always starts with '##'. Slice it off rather than
            # using lstrip("##"), which would also strip a literal '#' character.
            new_token = best_pair[0] + best_pair[1][2:]

            # Two different merge paths can yield the same string; keep the
            # existing id in that case so ids stay unique.
            if new_token not in self.vocab:
                self.vocab[new_token] = len(self.vocab)
            self.merge_rules.append(best_pair)

            # Apply the merge to every word in place.
            for word in list(self.word_tokens):
                tokens = self.word_tokens[word]
                i = 0
                while i < len(tokens) - 1:
                    if (tokens[i], tokens[i + 1]) == best_pair:
                        tokens[i:i + 2] = [new_token]
                    else:
                        i += 1

    def save_vocab(self, path='vocab.json'):
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(self.vocab, f, indent=2)

    def load_vocab(self, path='vocab.json'):
        with open(path, 'r', encoding='utf-8') as f:
            self.vocab = json.load(f)
        self.vocab_inv = {v: k for k, v in self.vocab.items()}

    def tokenize(self, sentence):
        tokens = []
        for word in sentence.strip().split():
            tokens.extend(self._wordpiece_tokenize(word))
        return tokens

    def _wordpiece_tokenize(self, word):
        # Simplification: start from characters and greedily merge adjacent
        # pieces whenever the merged piece is in the vocabulary. This is not
        # the longest-match-first lookup described above.
        chars = list(word)
        tokens = [chars[0]] + ['##' + c for c in chars[1:]]

        i = 0
        while i < len(tokens) - 1:
            new_token = tokens[i] + tokens[i + 1][2:]  # drop the '##' marker
            if new_token in self.vocab:
                tokens[i:i + 2] = [new_token]
            else:
                i += 1
        return [tok if tok in self.vocab else '[UNK]' for tok in tokens]

    def convert_tokens_to_ids(self, tokens):
        return [self.vocab.get(tok, self.vocab['[UNK]']) for tok in tokens]

    def convert_ids_to_tokens(self, ids):
        if not self.vocab_inv:
            self.vocab_inv = {v: k for k, v in self.vocab.items()}
        return [self.vocab_inv.get(i, '[UNK]') for i in ids]

    def __call__(self, sentence, padding=False, truncation=False, return_tensors=None):
        tokens = ['[CLS]'] + self.tokenize(sentence) + ['[SEP]']
        input_ids = self.convert_tokens_to_ids(tokens)

        if truncation and len(input_ids) > self.max_len:
            # Keep the closing [SEP] instead of cutting it off.
            input_ids = input_ids[:self.max_len - 1] + input_ids[-1:]

        attention_mask = [1] * len(input_ids)
        token_type_ids = [0] * len(input_ids)

        if padding and len(input_ids) < self.max_len:
            pad_len = self.max_len - len(input_ids)
            input_ids += [self.vocab['[PAD]']] * pad_len
            attention_mask += [0] * pad_len   # padding positions are masked out
            token_type_ids += [0] * pad_len

        output = {
            "input_ids": input_ids,
            "token_type_ids": token_type_ids,
            "attention_mask": attention_mask
        }

        if return_tensors == "pt":
            import torch
            for k in output:
                output[k] = torch.tensor(output[k]).unsqueeze(0)  # Add batch dim

        return output
```


```python
# Train
corpus = [
    "Artificial Intelligence (AI) refers to the technology that allows machines and computers to replicate human intelligence. Enables systems to perform tasks that require human-like decision-making, such as learning from data, identifying patterns, making informed choices and solving complex problems. Improves continuously by utilizing methods like machine learning and deep learning. Used in healthcare for diagnosing diseases, finance for fraud detection, e-commerce for personalized recommendations and transportation for self-driving cars. It also powers virtual assistants like Siri and Alexa, chatbots for customer support and manufacturing robots that automate production processes.",
    "Machine Learning is a subset of artificial intelligence (AI) that focuses on building systems that can learn from and make decisions based on data. Instead of being explicitly programmed to perform a task, a machine learning model uses algorithms to identify patterns within data and improve its performance over time without human intervention.",
    "Generative AI refers to a type of artificial intelligence designed to create new content, whether it's text, images, music, or even video. Unlike traditional AI, which typically focuses on analyzing and classifying data, generative AI goes a step further by using patterns it has learned from large datasets to generate new, original outputs. Essentially, it creates rather than just recognizes."
]

tokenizer = CustomWordPieceTokenizer(vocab_size=512)
tokenizer.train(corpus)
tokenizer.save_vocab('vocab.json')  # Save vocab

# Load and test
tokenizer.load_vocab('vocab.json')
x = tokenizer.tokenize("hi my name is sagar")
print("Tokens:", x)

y = tokenizer.convert_tokens_to_ids(x)
print("Token IDs:", y)

z = tokenizer.convert_ids_to_tokens(y)
print("Back to Tokens:", z)
```


```
Tokens: ['h', '##i', 'm', '##y', 'n', '##a', '##m', '##e', 'i', '##s', 's', '##a', '##g', '##ar']
Token IDs: [53, 20, 57, 35, 58, 12, 23, 16, 54, 29, 62, 12, 18, 451]
Back to Tokens: ['h', '##i', 'm', '##y', 'n', '##a', '##m', '##e', 'i', '##s', 's', '##a', '##g', '##ar']
```
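Because every continuation piece is marked with `##`, the original words can be rebuilt from the tokens. The helper below is a hypothetical addition for illustration and is not a method of `CustomWordPieceTokenizer`. It assumes the words were separated by single spaces.

```python
def pieces_to_text(tokens):
    """Glue '##' continuation pieces back onto the preceding piece."""
    words = []
    for token in tokens:
        if token.startswith("##") and words:
            words[-1] += token[2:]  # continuation: append without the marker
        else:
            words.append(token)     # a new word starts here
    return " ".join(words)


tokens = ['h', '##i', 'm', '##y', 'n', '##a', '##m', '##e',
          'i', '##s', 's', '##a', '##g', '##ar']
print(pieces_to_text(tokens))  # hi my name is sagar
```

Pieces replaced by `[UNK]` cannot be recovered this way, since the original characters are lost.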


```python
# vocab_size only matters for training; here the vocabulary is loaded from disk.
tokenizer = CustomWordPieceTokenizer(vocab_size=50)
tokenizer.load_vocab('vocab.json')

x1 = tokenizer("hi my name is sagar", padding=True, truncation=True, return_tensors='pt')
print("x1:", x1)

z1 = tokenizer.convert_ids_to_tokens(x1["input_ids"][0].tolist())
print("z1:", z1)
```


```
x1: {'input_ids': tensor([[  2,  53,  20,  57,  35,  58,  12,  23,  16,  54,  29,  62,  12,  18,
         451,   3]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
z1: ['[CLS]', 'h', '##i', 'm', '##y', 'n', '##a', '##m', '##e', 'i', '##s', 's', '##a', '##g', '##ar', '[SEP]']
```
