# Tokenization and Embeddings

Welcome to Topic 6! In this notebook, we'll explore how text is converted into numerical representations that transformers can process. We'll implement various tokenization strategies and understand embeddings in depth.

## Learning Objectives

By the end of this notebook, you will:
- Understand different tokenization strategies
- Implement BPE (Byte-Pair Encoding) from scratch
- Explore word, subword, and character-level tokenization
- Build token embedding layers with PyTorch's `nn.Embedding`
- Learn about positional and segment embeddings
- Build a complete tokenization pipeline

## Setup and Imports

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import defaultdict, Counter
import re
import json
from typing import List, Dict, Tuple, Optional
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# Set random seeds
np.random.seed(42)
torch.manual_seed(42)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Plotting configuration
plt.style.use('seaborn-v0_8-darkgrid')
plt.rcParams['figure.figsize'] = (10, 6)
plt.rcParams['font.size'] = 12

## 1. Tokenization Fundamentals

Let's start by understanding different levels of tokenization.

In [None]:
# Example text
text = "The quick brown fox jumps over the lazy dog. Tokenization is fascinating!"

# 1. Character-level tokenization
char_tokens = list(text)
print("Character-level tokenization:")
print(f"Number of tokens: {len(char_tokens)}")
print(f"First 20 tokens: {char_tokens[:20]}")
print()

# 2. Word-level tokenization (simple split)
word_tokens = text.split()
print("Word-level tokenization (simple):")
print(f"Number of tokens: {len(word_tokens)}")
print(f"Tokens: {word_tokens}")
print()

# 3. Word-level tokenization (with punctuation handling)
# Only ., ! and ? are kept as separate tokens; other punctuation (e.g. commas) is dropped.
word_tokens_punct = re.findall(r'\b\w+\b|[.!?]', text)
print("Word-level tokenization (with punctuation):")
print(f"Number of tokens: {len(word_tokens_punct)}")
print(f"Tokens: {word_tokens_punct}")
print()

# Visualization of tokenization levels
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Character-level histogram
char_counts = Counter(char_tokens)
axes[0].bar(range(len(char_counts)), list(char_counts.values()))
axes[0].set_title('Character Token Frequencies')
axes[0].set_xlabel('Character Index')
axes[0].set_ylabel('Frequency')

# Word-level histogram
word_counts = Counter(word_tokens_punct)
axes[1].bar(range(len(word_counts)), list(word_counts.values()))
axes[1].set_title('Word Token Frequencies')
axes[1].set_xlabel('Word Index')
axes[1].set_ylabel('Frequency')

# Vocabulary sizes comparison
vocab_sizes = [
    len(set(char_tokens)),
    len(set(word_tokens)),
    len(set(word_tokens_punct))
]
axes[2].bar(['Character', 'Word (simple)', 'Word (punct)'], vocab_sizes)
axes[2].set_title('Vocabulary Sizes')
axes[2].set_ylabel('Unique Tokens')

plt.tight_layout()
plt.show()
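
The trade-off between these levels is easiest to see in numbers: finer tokens give longer sequences but a smaller vocabulary. The following is a minimal sketch on the same example sentence. The names `demo_text`, `levels`, and `glued` are illustrative, and the `[^\w\s]` pattern is an assumption chosen here to keep every punctuation mark, not the regex used in the cell above.

```python
import re

demo_text = "The quick brown fox jumps over the lazy dog. Tokenization is fascinating!"

levels = {
    "character": list(demo_text),
    "whitespace": demo_text.split(),
    # \w+ matches runs of word characters; [^\w\s] matches one punctuation mark
    "regex": re.findall(r"\w+|[^\w\s]", demo_text),
}

# Compare sequence length and vocabulary size for each level
for name, toks in levels.items():
    print(f"{name:>10}: {len(toks):3d} tokens, {len(set(toks)):3d} unique")

# Whitespace splitting leaves punctuation glued to the preceding word
glued = [tok for tok in levels["whitespace"] if not tok.isalnum()]
print("Tokens with attached punctuation:", glued)
```

Tokens such as `dog.` and `dog` would end up as separate vocabulary entries under plain whitespace splitting, which is one reason word-level tokenizers usually split off punctuation first.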

## 2. Building a Simple Tokenizer

Let's build a basic word-level tokenizer with vocabulary management.

In [None]:
class SimpleTokenizer:
    def __init__(self, vocab_size=None):
        self.vocab_size = vocab_size
        self.token_to_id = {}
        self.id_to_token = {}
        
        # Special tokens
        self.pad_token = '<PAD>'
        self.unk_token = '<UNK>'
        self.bos_token = '<BOS>'
        self.eos_token = '<EOS>'
        
        # Initialize special tokens
        special_tokens = [self.pad_token, self.unk_token, self.bos_token, self.eos_token]
        for i, token in enumerate(special_tokens):
            self.token_to_id[token] = i
            self.id_to_token[i] = token
            
    def build_vocab(self, texts: List[str]):
        """Build vocabulary from texts."""
        # Count token frequencies
        token_freq = Counter()
        for text in texts:
            tokens = self._tokenize(text)
            token_freq.update(tokens)
            
        # Sort by frequency
        sorted_tokens = sorted(token_freq.items(), key=lambda x: x[1], reverse=True)
        
        # Add to vocabulary
        current_id = len(self.token_to_id)
        for token, freq in sorted_tokens:
            if self.vocab_size and current_id >= self.vocab_size:
                break
            if token not in self.token_to_id:
                self.token_to_id[token] = current_id
                self.id_to_token[current_id] = token
                current_id += 1
                
        print(f"Vocabulary size: {len(self.token_to_id)}")
        
    def _tokenize(self, text: str) -> List[str]:
        """Simple word tokenization."""
        # Convert to lowercase and split
        tokens = re.findall(r'\b\w+\b|[.!?]', text.lower())
        return tokens
    
    def encode(self, text: str, add_special_tokens: bool = True) -> List[int]:
        """Convert text to token ids."""
        tokens = self._tokenize(text)
        
        if add_special_tokens:
            tokens = [self.bos_token] + tokens + [self.eos_token]
            
        # Convert to ids
        ids = []
        for token in tokens:
            if token in self.token_to_id:
                ids.append(self.token_to_id[token])
            else:
                ids.append(self.token_to_id[self.unk_token])
                
        return ids
    
    def decode(self, ids: List[int], skip_special_tokens: bool = True) -> str:
        """Convert token ids back to text."""
        tokens = []
        special_tokens = {self.pad_token, self.unk_token, self.bos_token, self.eos_token}
        
        for token_id in ids:
            if token_id in self.id_to_token:
                token = self.id_to_token[token_id]
                if skip_special_tokens and token in special_tokens:
                    continue
                tokens.append(token)
                
        return ' '.join(tokens)

# Test the tokenizer
sample_texts = [
    "The quick brown fox jumps over the lazy dog.",
    "Machine learning is amazing!",
    "Transformers revolutionized NLP.",
    "Tokenization is a crucial preprocessing step."
]

tokenizer = SimpleTokenizer(vocab_size=50)
tokenizer.build_vocab(sample_texts)

# Test encoding and decoding
test_text = "Machine learning is fascinating!"
encoded = tokenizer.encode(test_text)
decoded = tokenizer.decode(encoded)

print(f"\nOriginal text: {test_text}")
print(f"Encoded: {encoded}")
print(f"Decoded: {decoded}")

# Show vocabulary
print("\nFirst 20 tokens in vocabulary:")
for i in range(min(20, len(tokenizer.id_to_token))):
    print(f"{i}: {tokenizer.id_to_token[i]}")
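
A word-level tokenizer is lossy: casing, spacing, and every out-of-vocabulary word are gone after a round trip. The sketch below shows this with stand-alone helper functions rather than the class above. `build_vocab`, `encode_words`, and `decode_ids` are hypothetical names introduced only for this illustration.

```python
import re
from collections import Counter

WORD_RE = re.compile(r"\b\w+\b|[.!?]")

def build_vocab(texts, specials=("<PAD>", "<UNK>")):
    """Map special tokens first, then corpus tokens by descending frequency."""
    counts = Counter(tok for t in texts for tok in WORD_RE.findall(t.lower()))
    vocab = {tok: i for i, tok in enumerate(specials)}
    for tok, _ in counts.most_common():
        vocab.setdefault(tok, len(vocab))
    return vocab

def encode_words(text, vocab):
    """Replace each token by its id, falling back to <UNK>."""
    unk = vocab["<UNK>"]
    return [vocab.get(tok, unk) for tok in WORD_RE.findall(text.lower())]

def decode_ids(ids, vocab):
    """Invert the mapping; unknown words cannot be recovered."""
    inverse = {i: tok for tok, i in vocab.items()}
    return " ".join(inverse[i] for i in ids)

toy_vocab = build_vocab(["Machine learning is amazing!"])
toy_ids = encode_words("Machine learning is fascinating!", toy_vocab)
round_trip = decode_ids(toy_ids, toy_vocab)

print(toy_ids)      # [2, 3, 4, 1, 6]
print(round_trip)   # machine learning is <UNK> !
```

The word "fascinating" never appeared in the toy corpus, so it collapses to `<UNK>`. Subword tokenization, covered next, is one way to reduce this loss.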

## 3. Byte-Pair Encoding (BPE) Implementation

Now let's implement BPE, a popular subword tokenization algorithm.

In [None]:
class BPETokenizer:
    def __init__(self, vocab_size: int = 1000):
        self.vocab_size = vocab_size
        self.word_tokenizer = re.compile(r'\b\w+\b|[.!?]')
        self.vocab = {}
        self.merges = []
        
    def _get_word_frequencies(self, texts: List[str]) -> Dict[str, int]:
        """Get frequency of each word in the corpus."""
        word_freq = Counter()
        for text in texts:
            words = self.word_tokenizer.findall(text.lower())
            word_freq.update(words)
        return dict(word_freq)
    
    def _get_pair_frequencies(self, word_splits: Dict[Tuple, int]) -> Counter:
        """Count frequency of adjacent pairs."""
        pair_freq = Counter()
        for word_tuple, freq in word_splits.items():
            for i in range(len(word_tuple) - 1):
                pair = (word_tuple[i], word_tuple[i + 1])
                pair_freq[pair] += freq
        return pair_freq
    
    def _merge_pair(self, word_splits: Dict[Tuple, int], pair: Tuple[str, str]) -> Dict[Tuple, int]:
        """Merge every occurrence of `pair` inside each word."""
        new_word_splits = {}
        for word_tuple, freq in word_splits.items():
            new_word = []
            i = 0
            while i < len(word_tuple):
                if i < len(word_tuple) - 1 and word_tuple[i] == pair[0] and word_tuple[i + 1] == pair[1]:
                    new_word.append(pair[0] + pair[1])
                    i += 2
                else:
                    new_word.append(word_tuple[i])
                    i += 1
            new_word_splits[tuple(new_word)] = freq
        return new_word_splits
    
    def train(self, texts: List[str]):
        """Train BPE on texts."""
        # Get word frequencies
        word_freq = self._get_word_frequencies(texts)
        
        # Initialize with character-level splits
        word_splits = {}
        for word, freq in word_freq.items():
            word_tuple = tuple(word) + ('</w>',)  # Add end-of-word token
            word_splits[word_tuple] = freq
            
        # Initialize vocabulary with characters
        vocab_id = 0
        for word_tuple in word_splits:
            for char in word_tuple:
                if char not in self.vocab:
                    self.vocab[char] = vocab_id
                    vocab_id += 1
                    
        print(f"Initial vocabulary size: {len(self.vocab)}")
        
        # Perform merges
        # Clamp at zero in case the character vocabulary already exceeds vocab_size
        progress_bar = tqdm(total=max(0, self.vocab_size - len(self.vocab)), desc="BPE Training")
        
        while len(self.vocab) < self.vocab_size:
            # Get pair frequencies
            pair_freq = self._get_pair_frequencies(word_splits)
            
            if not pair_freq:
                break
                
            # Find most frequent pair
            most_frequent_pair = max(pair_freq, key=pair_freq.get)
            
            # Merge the pair
            word_splits = self._merge_pair(word_splits, most_frequent_pair)
            
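            # Always record the merge so tokenize() replays exactly what training did,
            # even when two different pairs concatenate to the same string
            self.merges.append(most_frequent_pair)
            
            # Add to vocabulary
            merged = most_frequent_pair[0] + most_frequent_pair[1]
            if merged not in self.vocab:
                self.vocab[merged] = vocab_id
                vocab_id += 1
                progress_bar.update(1)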
                
        progress_bar.close()
        print(f"Final vocabulary size: {len(self.vocab)}")
        
    def tokenize(self, text: str) -> List[str]:
        """Tokenize text using learned BPE."""
        words = self.word_tokenizer.findall(text.lower())
        tokens = []
        
        for word in words:
            word_tokens = list(word) + ['</w>']
            
            # Apply merges
            for merge in self.merges:
                i = 0
                while i < len(word_tokens) - 1:
                    if word_tokens[i] == merge[0] and word_tokens[i + 1] == merge[1]:
                        word_tokens = word_tokens[:i] + [merge[0] + merge[1]] + word_tokens[i + 2:]
                    else:
                        i += 1
                        
            tokens.extend(word_tokens)
            
        return tokens

# Train BPE tokenizer
corpus = [
    "The quick brown fox jumps over the lazy dog.",
    "Machine learning algorithms are powerful.",
    "Natural language processing is fascinating.",
    "Transformers have revolutionized NLP.",
    "Tokenization is important for text processing.",
    "Subword tokenization helps with rare words.",
    "BPE is a popular tokenization method.",
    "Neural networks learn representations."
]

bpe_tokenizer = BPETokenizer(vocab_size=100)
bpe_tokenizer.train(corpus)

# Test BPE tokenization
test_sentences = [
    "Machine learning is amazing.",
    "Tokenization helps process text.",
    "Unknown words are handled well."
]

print("\nBPE Tokenization Examples:")
for sentence in test_sentences:
    tokens = bpe_tokenizer.tokenize(sentence)
    print(f"\nText: {sentence}")
    print(f"Tokens: {tokens}")
    
# Show some learned merges
print("\nFirst 20 BPE merges:")
for i, merge in enumerate(bpe_tokenizer.merges[:20]):
    print(f"{i+1}: {merge[0]} + {merge[1]} = {merge[0] + merge[1]}")
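
To see what a single BPE training step does, it helps to trace a few merges by hand on a tiny corpus. The sketch below uses a four-word toy frequency table (an assumption for illustration, not the notebook's corpus) and stand-alone functions `count_pairs` and `apply_merge` that mirror the class methods above. Ties are broken by dictionary insertion order here; other implementations may break them differently.

```python
from collections import Counter

def count_pairs(splits):
    """Count adjacent symbol pairs, weighted by word frequency."""
    pairs = Counter()
    for symbols, freq in splits.items():
        for left, right in zip(symbols, symbols[1:]):
            pairs[(left, right)] += freq
    return pairs

def apply_merge(splits, pair):
    """Replace every occurrence of `pair` with the concatenated symbol."""
    merged = {}
    for symbols, freq in splits.items():
        out, i = [], 0
        while i < len(symbols):
            if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
                out.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                out.append(symbols[i])
                i += 1
        merged[tuple(out)] = freq
    return merged

# Toy word counts (illustrative only)
toy_counts = {"low": 5, "lower": 2, "newest": 6, "widest": 3}
toy_splits = {tuple(word) + ("</w>",): freq for word, freq in toy_counts.items()}

learned = []
for _ in range(3):
    pairs = count_pairs(toy_splits)
    best = max(pairs, key=pairs.get)  # first pair with the highest count
    learned.append(best)
    toy_splits = apply_merge(toy_splits, best)

print(learned)
# [('e', 's'), ('es', 't'), ('est', '</w>')]
print(list(toy_splits))
```

After three merges the shared suffix of "newest" and "widest" has become the single symbol `est</w>`, which is how BPE discovers frequent word endings without any linguistic knowledge.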

## 4. Word Embeddings

Now let's explore how tokens are converted to dense vector representations.

In [None]:
class EmbeddingLayer(nn.Module):
    def __init__(self, vocab_size: int, embedding_dim: int, padding_idx: int = 0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        self.embedding_dim = embedding_dim
        
        # Initialize weights
        self._init_weights()
        
    def _init_weights(self):
        """Initialize embedding weights."""
        # Xavier uniform initialization
        nn.init.xavier_uniform_(self.embedding.weight)
        
        # Set padding embedding to zero
        if self.embedding.padding_idx is not None:
            with torch.no_grad():
                self.embedding.weight[self.embedding.padding_idx].fill_(0)
                
    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embedding(input_ids)

# Create embedding layer
vocab_size = 1000
embedding_dim = 128
embedding_layer = EmbeddingLayer(vocab_size, embedding_dim)

# Test embeddings
token_ids = torch.tensor([[1, 5, 10, 15, 0, 0],  # 0 is padding
                         [2, 7, 12, 0, 0, 0]])
embeddings = embedding_layer(token_ids)

print(f"Input shape: {token_ids.shape}")
print(f"Embedding shape: {embeddings.shape}")
print(f"Embedding dimension: {embedding_dim}")

# Visualize embedding space (2D projection using PCA)
# The weights are untrained, so expect an unstructured cloud rather than semantic clusters
from sklearn.decomposition import PCA

# Get embeddings for first 100 tokens
token_indices = torch.arange(100)
token_embeddings = embedding_layer(token_indices).detach().numpy()

# Apply PCA
pca = PCA(n_components=2)
embeddings_2d = pca.fit_transform(token_embeddings)

# Plot
plt.figure(figsize=(10, 8))
plt.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1], alpha=0.6)

# Annotate some points
for i in range(0, 100, 10):
    plt.annotate(str(i), (embeddings_2d[i, 0], embeddings_2d[i, 1]))
    
plt.xlabel('First Principal Component')
plt.ylabel('Second Principal Component')
plt.title('Token Embeddings Visualization (PCA)')
plt.grid(True, alpha=0.3)
plt.show()

# Show embedding statistics
print("\nEmbedding Statistics:")
print(f"Mean: {token_embeddings.mean():.4f}")
print(f"Std: {token_embeddings.std():.4f}")
print(f"Min: {token_embeddings.min():.4f}")
print(f"Max: {token_embeddings.max():.4f}")
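
An embedding layer is a trainable lookup table: selecting row `i` gives the same result as multiplying a one-hot vector by the weight matrix. The sketch below checks that equivalence on a tiny layer and also shows that the `padding_idx` row receives no gradient. The sizes and the names `tiny_emb`, `ids`, and `via_matmul` are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

tiny_emb = nn.Embedding(num_embeddings=6, embedding_dim=4, padding_idx=0)
ids = torch.tensor([3, 1, 0])  # the last id is the padding token

# Direct lookup
looked_up = tiny_emb(ids)

# Same result via one-hot vectors times the weight matrix
one_hot = F.one_hot(ids, num_classes=6).float()
via_matmul = one_hot @ tiny_emb.weight

print(torch.allclose(looked_up, via_matmul))  # True

# Backpropagate a dummy loss and inspect per-row gradients
looked_up.sum().backward()
print(tiny_emb.weight.grad[0])  # padding row: all zeros
print(tiny_emb.weight.grad[3])  # a used row: all ones
```

The lookup form is preferred in practice because it avoids materializing the one-hot matrix, which would be very large for realistic vocabulary sizes.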

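## 5. Positional and Segment Embeddings

Self-attention has no built-in notion of token order, so transformers add positional embeddings to the token embeddings. BERT-style models also add segment embeddings that mark which input sentence each token belongs to.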

In [None]:
class TransformerEmbeddings(nn.Module):
    """Complete embedding layer for transformers including token, position, and segment embeddings."""
    
    def __init__(self, vocab_size: int, d_model: int, max_seq_length: int, 
                 num_segments: int = 2, dropout: float = 0.1):
        super().__init__()
        
        # Token embeddings
        self.token_embeddings = nn.Embedding(vocab_size, d_model, padding_idx=0)
        
        # Positional embeddings (learnable)
        self.position_embeddings = nn.Embedding(max_seq_length, d_model)
        
        # Segment embeddings (for BERT-style models)
        self.segment_embeddings = nn.Embedding(num_segments, d_model)
        
        # Layer normalization and dropout
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
        self.d_model = d_model
        
    def forward(self, input_ids: torch.Tensor, 
                segment_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
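        seq_length = input_ids.size(1)
        max_positions = self.position_embeddings.num_embeddings
        if seq_length > max_positions:
            raise ValueError(
                f"Sequence length {seq_length} exceeds max_seq_length {max_positions}"
            )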
        
        # Token embeddings
        token_embeds = self.token_embeddings(input_ids)
        
        # Position embeddings
        position_ids = torch.arange(seq_length, device=input_ids.device).unsqueeze(0)
        position_embeds = self.position_embeddings(position_ids)
        
        # Segment embeddings
        if segment_ids is None:
            segment_ids = torch.zeros_like(input_ids)
        segment_embeds = self.segment_embeddings(segment_ids)
        
        # Combine embeddings
        embeddings = token_embeds + position_embeds + segment_embeds
        
        # Apply layer norm and dropout
        embeddings = self.layer_norm(embeddings)
        embeddings = self.dropout(embeddings)
        
        return embeddings

# Test transformer embeddings
vocab_size = 1000
d_model = 256
max_seq_length = 128

transformer_embeddings = TransformerEmbeddings(vocab_size, d_model, max_seq_length)

# Create sample input
batch_size = 2
seq_length = 10
input_ids = torch.randint(1, vocab_size, (batch_size, seq_length))
segment_ids = torch.cat([torch.zeros(batch_size, 5, dtype=torch.long),
                        torch.ones(batch_size, 5, dtype=torch.long)], dim=1)

# Get embeddings
embeddings = transformer_embeddings(input_ids, segment_ids)
print(f"Input shape: {input_ids.shape}")
print(f"Output embeddings shape: {embeddings.shape}")

# Visualize different embedding components
with torch.no_grad():
    # Get individual components
    token_only = transformer_embeddings.token_embeddings(input_ids[0])
    position_ids = torch.arange(seq_length)
    position_only = transformer_embeddings.position_embeddings(position_ids)
    segment_only = transformer_embeddings.segment_embeddings(segment_ids[0])

# Plot raw embedding values as heatmaps
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Token embeddings
axes[0].imshow(token_only.numpy().T, aspect='auto', cmap='coolwarm')
axes[0].set_title('Token Embeddings')
axes[0].set_xlabel('Position')
axes[0].set_ylabel('Embedding Dimension')

# Position embeddings
axes[1].imshow(position_only.numpy().T, aspect='auto', cmap='coolwarm')
axes[1].set_title('Position Embeddings')
axes[1].set_xlabel('Position')
axes[1].set_ylabel('Embedding Dimension')

# Segment embeddings
axes[2].imshow(segment_only.numpy().T, aspect='auto', cmap='coolwarm')
axes[2].set_title('Segment Embeddings')
axes[2].set_xlabel('Position')
axes[2].set_ylabel('Embedding Dimension')

plt.tight_layout()
plt.show()
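
The class above learns its position embeddings. A common alternative, not used in this notebook, is the fixed sinusoidal encoding from the original Transformer paper, where even dimensions use a sine and odd dimensions a cosine of the position at geometrically spaced frequencies. The following is a minimal sketch assuming an even `d_model`; the function name `sinusoidal_positions` is hypothetical.

```python
import math
import torch

def sinusoidal_positions(max_len: int, d_model: int) -> torch.Tensor:
    """Return a (max_len, d_model) table of fixed sinusoidal encodings."""
    if d_model % 2 != 0:
        raise ValueError("this sketch assumes an even d_model")
    position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
    # 10000^(-2i/d_model) for i = 0 .. d_model/2 - 1
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float32)
        * (-math.log(10000.0) / d_model)
    )
    table = torch.zeros(max_len, d_model)
    table[:, 0::2] = torch.sin(position * div_term)  # even dimensions
    table[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
    return table

pos_table = sinusoidal_positions(max_len=128, d_model=256)
print(pos_table.shape)   # torch.Size([128, 256])
print(pos_table[0, :4])  # position 0: sin(0)=0 and cos(0)=1 alternate
```

Because the table is computed rather than learned, it adds no parameters and can be generated for any sequence length.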

## 6. Advanced Tokenization: WordPiece

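Let's implement a simplified version of the WordPiece tokenizer used in BERT. The class below covers only the tokenization step, a greedy longest-match-first lookup against a fixed vocabulary; learning that vocabulary from a corpus is out of scope here.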

In [None]:
class WordPieceTokenizer:
    def __init__(self, vocab: List[str], unk_token: str = '[UNK]', max_word_length: int = 100):
        self.vocab = set(vocab)
        self.unk_token = unk_token
        self.max_word_length = max_word_length
        
    def tokenize(self, text: str) -> List[str]:
        """Tokenize text using WordPiece algorithm."""
        output_tokens = []
        
        # First, do basic word tokenization
        words = text.lower().split()
        
        for word in words:
            if len(word) > self.max_word_length:
                output_tokens.append(self.unk_token)
                continue
                
            is_bad = False
            sub_tokens = []
            start = 0
            
            while start < len(word):
                end = len(word)
                cur_substr = None
                
                while start < end:
                    substr = word[start:end]
                    if start > 0:
                        substr = '##' + substr
                        
                    if substr in self.vocab:
                        cur_substr = substr
                        break
                        
                    end -= 1
                    
                if cur_substr is None:
                    is_bad = True
                    break
                    
                sub_tokens.append(cur_substr)
                start = end
                
            if is_bad:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(sub_tokens)
                
        return output_tokens

# Create a simple vocabulary
vocab = [
    '[UNK]', '[CLS]', '[SEP]', '[PAD]', '[MASK]',
    'the', 'quick', 'brown', 'fox', 'jumps', 'over', 'lazy', 'dog',
    'machine', 'learning', 'is', 'amazing', 'transform', '##er', '##s',
    'token', '##ization', 'neural', 'network', '##ing',
    'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
    'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
    '##a', '##e', '##i', '##o', '##u', '##n', '##t', '##s', '##d', '##r'
]

wordpiece_tokenizer = WordPieceTokenizer(vocab)

# Test WordPiece tokenization
test_sentences = [
    "the quick brown fox",
    "machine learning is amazing",
    "transformers tokenization",
    "unknown words like cryptocurrency"
]

print("WordPiece Tokenization Examples:")
for sentence in test_sentences:
    tokens = wordpiece_tokenizer.tokenize(sentence)
    print(f"\nText: {sentence}")
    print(f"Tokens: {tokens}")

# Show the subword split for a single word
word = "tokenization"
tokens = wordpiece_tokenizer.tokenize(word)
print(f"\nSubword tokenization of '{word}': {tokens}")
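
One consequence of greedy longest-match is that a word becomes `[UNK]` as a whole as soon as any remaining piece cannot be matched, even if its prefix was found. The sketch below isolates the per-word loop to show this. `greedy_wordpiece` and `toy_pieces` are illustrative names, and the five-entry vocabulary is an assumption chosen to trigger both outcomes.

```python
def greedy_wordpiece(word, vocab, unk="[UNK]"):
    """Split one lowercase word by repeatedly taking the longest vocab match."""
    pieces, start = [], 0
    while start < len(word):
        end = len(word)
        match = None
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation pieces carry a ## prefix
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            return [unk]  # one unmatched piece invalidates the whole word
        pieces.append(match)
        start = end
    return pieces

toy_pieces = {"token", "##ization", "##s", "un", "##known"}

for w in ["tokenization", "tokens", "unknown", "tokenizer"]:
    print(f"{w:>12} -> {greedy_wordpiece(w, toy_pieces)}")
```

Here "tokenizer" starts with the known piece `token`, but no continuation piece covers `izer`, so the whole word falls back to `[UNK]`. Real vocabularies include single-character continuation pieces to make this fallback rare.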

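## 7. Embedding Analysis and Similarity

The embeddings in this section are randomly initialized, so the similarities carry no linguistic meaning. The goal is to show the mechanics of computing and inspecting cosine similarity.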

In [None]:
def cosine_similarity(embed1: torch.Tensor, embed2: torch.Tensor) -> float:
    """Calculate cosine similarity between two embeddings."""
    dot_product = torch.sum(embed1 * embed2)
    norm1 = torch.norm(embed1)
    norm2 = torch.norm(embed2)
    return (dot_product / (norm1 * norm2)).item()

# Create embeddings for similarity analysis
vocab_size = 100
embedding_dim = 64
embedding_layer = nn.Embedding(vocab_size, embedding_dim)

# Get embeddings for some tokens
token_ids = torch.arange(20)
embeddings = embedding_layer(token_ids)

# Calculate similarity matrix
similarity_matrix = torch.zeros(20, 20)
for i in range(20):
    for j in range(20):
        similarity_matrix[i, j] = cosine_similarity(embeddings[i], embeddings[j])

# Visualize similarity matrix
plt.figure(figsize=(10, 8))
sns.heatmap(similarity_matrix.numpy(), cmap='coolwarm', center=0, 
            square=True, linewidths=0.5, cbar_kws={"shrink": 0.8})
plt.title('Token Embedding Similarity Matrix')
plt.xlabel('Token ID')
plt.ylabel('Token ID')
plt.show()

# Find most similar tokens
print("Most similar token pairs (excluding self-similarity):")
similarity_matrix_no_diag = similarity_matrix.clone()
similarity_matrix_no_diag.fill_diagonal_(-1)  # Exclude self-similarity

num_tokens = similarity_matrix_no_diag.size(1)
for _ in range(5):
    max_sim = similarity_matrix_no_diag.max().item()
    # argmax indexes the flattened matrix, so convert it back to (row, column)
    i, j = divmod(similarity_matrix_no_diag.argmax().item(), num_tokens)
    print(f"Token {i} and Token {j}: similarity = {max_sim:.4f}")
    similarity_matrix_no_diag[i, j] = -1
    similarity_matrix_no_diag[j, i] = -1
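
The pairwise loop is fine for 20 tokens, but the same matrix can be computed in one step: normalize each row to unit length, then take a matrix product. The sketch below uses random vectors as a stand-in for an embedding table and picks the top pairs from the upper triangle so each pair appears once. The names `vectors`, `sim`, and `top_pairs` are illustrative.

```python
import torch
import torch.nn.functional as F

# Random stand-in for 20 embedding vectors of dimension 64
gen = torch.Generator().manual_seed(0)
vectors = torch.randn(20, 64, generator=gen)

# Rows scaled to unit length, so dot products equal cosine similarities
unit = F.normalize(vectors, dim=1)
sim = unit @ unit.T

# Spot-check one entry against PyTorch's built-in function
reference = F.cosine_similarity(vectors[3], vectors[7], dim=0)
print(torch.allclose(sim[3, 7], reference, atol=1e-6))

# Top 5 distinct pairs: look only above the diagonal
rows, cols = torch.triu_indices(20, 20, offset=1)
top = torch.topk(sim[rows, cols], k=5)
top_pairs = [
    (rows[idx].item(), cols[idx].item(), score)
    for score, idx in zip(top.values.tolist(), top.indices.tolist())
]
for i, j, score in top_pairs:
    print(f"Token {i} and Token {j}: similarity = {score:.4f}")
```

Restricting the search to the upper triangle removes the need to overwrite the diagonal or to mask out mirrored entries after each pick.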

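## 8. Complete Tokenization Pipeline

The pipeline below combines preprocessing, special tokens, truncation, padding, and attention masks. Its default vocabulary holds only the special tokens, so in this demo every ordinary word is encoded as `[UNK]`.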

In [None]:
class TokenizationPipeline:
    """Complete tokenization pipeline for transformers."""
    
    def __init__(self, vocab_file: Optional[str] = None, vocab_size: int = 1000):
        self.vocab_size = vocab_size
        self.special_tokens = {
            'pad_token': '[PAD]',
            'unk_token': '[UNK]',
            'cls_token': '[CLS]',
            'sep_token': '[SEP]',
            'mask_token': '[MASK]'
        }
        
        # Build vocabulary
        if vocab_file:
            self.load_vocab(vocab_file)
        else:
            self.build_default_vocab()
            
    def build_default_vocab(self):
        """Build a default vocabulary."""
        self.token_to_id = {}
        self.id_to_token = {}
        
        # Add special tokens
        for i, (_, token) in enumerate(self.special_tokens.items()):
            self.token_to_id[token] = i
            self.id_to_token[i] = token
            
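        print(f"Vocabulary initialized with {len(self.token_to_id)} special tokens")
        
    def load_vocab(self, vocab_file: str):
        """Load a vocabulary from disk.
        
        Assumption: the file is a JSON object mapping token -> id.
        The notebook does not specify a vocabulary file format.
        """
        with open(vocab_file, 'r', encoding='utf-8') as f:
            self.token_to_id = json.load(f)
        self.id_to_token = {idx: token for token, idx in self.token_to_id.items()}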
        
    def preprocess_text(self, text: str) -> str:
        """Preprocess text before tokenization."""
        # Convert to lowercase
        text = text.lower()
        
        # Remove extra whitespace
        text = ' '.join(text.split())
        
        return text
    
    def encode(self, text: str, max_length: Optional[int] = None,
               truncation: bool = True, padding: bool = True) -> Dict[str, torch.Tensor]:
        """Encode text to model inputs."""
        # Preprocess
        text = self.preprocess_text(text)
        
        # Simple word tokenization for demo
        tokens = text.split()
        
        # Add special tokens
        tokens = [self.special_tokens['cls_token']] + tokens + [self.special_tokens['sep_token']]
        
        # Convert to ids
        input_ids = []
        for token in tokens:
            if token in self.token_to_id:
                input_ids.append(self.token_to_id[token])
            else:
                input_ids.append(self.token_to_id[self.special_tokens['unk_token']])
                
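        # Truncation (re-append [SEP] so a truncated sequence still ends with it)
        if max_length and truncation and len(input_ids) > max_length:
            sep_id = self.token_to_id[self.special_tokens['sep_token']]
            input_ids = input_ids[:max_length - 1] + [sep_id]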
            
        # Create attention mask
        attention_mask = [1] * len(input_ids)
        
        # Padding
        if max_length and padding:
            padding_length = max_length - len(input_ids)
            input_ids = input_ids + [self.token_to_id[self.special_tokens['pad_token']]] * padding_length
            attention_mask = attention_mask + [0] * padding_length
            
        return {
            'input_ids': torch.tensor(input_ids),
            'attention_mask': torch.tensor(attention_mask)
        }
    
    def batch_encode(self, texts: List[str], max_length: Optional[int] = None,
                    truncation: bool = True, padding: bool = True) -> Dict[str, torch.Tensor]:
        """Encode multiple texts."""
        encoded_batch = [self.encode(text, max_length, truncation, False) for text in texts]
        
        # Find max length in batch
        if padding and not max_length:
            max_length = max(len(enc['input_ids']) for enc in encoded_batch)
            
        # Pad to max length
        if padding:
            for enc in encoded_batch:
                padding_length = max_length - len(enc['input_ids'])
                if padding_length > 0:
                    enc['input_ids'] = torch.cat([
                        enc['input_ids'],
                        torch.full((padding_length,), self.token_to_id[self.special_tokens['pad_token']],
                                   dtype=torch.long)
                    ])
                    enc['attention_mask'] = torch.cat([
                        enc['attention_mask'],
                        torch.zeros(padding_length, dtype=torch.long)
                    ])
                    
        # Stack into batch
        return {
            'input_ids': torch.stack([enc['input_ids'] for enc in encoded_batch]),
            'attention_mask': torch.stack([enc['attention_mask'] for enc in encoded_batch])
        }

# Test the complete pipeline
pipeline = TokenizationPipeline()

# Single text encoding
text = "Hello, this is a test sentence!"
encoded = pipeline.encode(text, max_length=15)
print("Single text encoding:")
print(f"Text: {text}")
print(f"Input IDs: {encoded['input_ids']}")
print(f"Attention Mask: {encoded['attention_mask']}")

# Batch encoding
texts = [
    "This is the first sentence.",
    "This is a much longer second sentence that might need truncation.",
    "Short one."
]

batch_encoded = pipeline.batch_encode(texts, max_length=20)
print("\nBatch encoding:")
print(f"Batch Input IDs shape: {batch_encoded['input_ids'].shape}")
print(f"Batch Attention Mask shape: {batch_encoded['attention_mask'].shape}")

# Visualize batch
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8))

# Input IDs
im1 = ax1.imshow(batch_encoded['input_ids'].numpy(), aspect='auto', cmap='viridis')
ax1.set_title('Batch Input IDs')
ax1.set_xlabel('Token Position')
ax1.set_ylabel('Batch Index')
plt.colorbar(im1, ax=ax1)

# Attention Mask
im2 = ax2.imshow(batch_encoded['attention_mask'].numpy(), aspect='auto', cmap='RdBu')
ax2.set_title('Batch Attention Mask')
ax2.set_xlabel('Token Position')
ax2.set_ylabel('Batch Index')
plt.colorbar(im2, ax=ax2)

plt.tight_layout()
plt.show()
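
The core of batching is small enough to state without tensors: cut or extend every sequence to a common length, and record which positions hold real tokens. The following is a minimal pure-Python sketch under the assumption that the pad id is 0; `pad_and_mask` is a hypothetical helper, not part of the class above.

```python
from typing import List, Optional, Tuple

def pad_and_mask(
    sequences: List[List[int]],
    pad_id: int = 0,
    max_length: Optional[int] = None,
) -> Tuple[List[List[int]], List[List[int]]]:
    """Right-truncate and right-pad id lists; mask is 1 for real tokens, 0 for padding."""
    target = max_length or max(len(seq) for seq in sequences)
    input_ids, attention_mask = [], []
    for seq in sequences:
        seq = seq[:target]           # truncate from the right
        n_pad = target - len(seq)    # never negative after truncation
        input_ids.append(seq + [pad_id] * n_pad)
        attention_mask.append([1] * len(seq) + [0] * n_pad)
    return input_ids, attention_mask

batch = [[2, 5, 6, 3], [2, 7, 3]]

ids_dynamic, mask_dynamic = pad_and_mask(batch)
print(ids_dynamic)   # [[2, 5, 6, 3], [2, 7, 3, 0]]
print(mask_dynamic)  # [[1, 1, 1, 1], [1, 1, 1, 0]]

ids_fixed, mask_fixed = pad_and_mask(batch, max_length=3)
print(ids_fixed)     # [[2, 5, 6], [2, 7, 3]]
```

Note that this sketch truncates blindly from the right, so a trailing separator token (id 3 in the first toy sequence) can be cut off. Production tokenizers typically truncate the content first and add special tokens afterwards.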

## Summary

In this notebook, we've explored tokenization and embeddings in depth:

1. **Tokenization Levels**: Character, word, and subword tokenization
2. **Simple Tokenizer**: Built a basic word-level tokenizer
3. **BPE Implementation**: Implemented Byte-Pair Encoding from scratch
4. **Embeddings**: Explored token, positional, and segment embeddings
5. **WordPiece**: Implemented simplified WordPiece tokenization
6. **Similarity Analysis**: Computed cosine similarities between token embeddings
7. **Complete Pipeline**: Built a full tokenization pipeline

Key takeaways:
- Tokenization is crucial for converting text to numerical form
- Subword tokenization balances vocabulary size and coverage
- Embeddings provide dense representations of discrete tokens
- Position and segment embeddings add structural information
- A good tokenization pipeline handles unknown tokens, truncation, padding, and attention masks

Next, we'll explore how to train transformers effectively!