# Vocabulary (Tokenizer) Extension/Expansion/Adaptation

## Prepare Data

In [1]:
from datasets import load_dataset

# this dataset has already fixed encoding using ftfy (as is used by me in the preprocessing steps of other datasets)
dataset = load_dataset("HuggingFaceFW/fineweb-2", "ces_Latn", split="train", streaming=True)
dataset

Resolving data files:   0%|          | 0/25 [00:00<?, ?it/s]

IterableDataset({
    features: ['text', 'id', 'dump', 'url', 'date', 'file_path', 'language', 'language_score', 'language_script', 'minhash_cluster_size', 'top_langs'],
    n_shards: 25
})

In [2]:
#we need only texts
dataset = dataset.remove_columns(["id", "dump", "url", "date", "file_path", "language", "language_score", "language_script", "minhash_cluster_size", "top_langs"])
dataset

IterableDataset({
    features: ['text'],
    n_shards: 25
})

In [3]:
#shuffle to be sure we select "random sample"
dataset = dataset.shuffle(seed=42, buffer_size=10000)

In [4]:
#limit the number of samples
dataset = dataset.take(100000)
dataset

IterableDataset({
    features: ['text'],
    n_shards: 25
})

In [5]:
#prepare iterator that outputs only texts
def serve_texts():
    for example in dataset:
        yield example["text"]

## Load the Original Tokenizer

In [6]:
from transformers import AutoTokenizer

old_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")

In [7]:
example = "Řeřicha je květina vyskutjící se v Česku a na Moravě."
old_tokenizer.tokenize(example)

['Åĺ',
 'e',
 'ÅĻ',
 'icha',
 'Ġje',
 'ĠkvÄĽt',
 'ina',
 'Ġvys',
 'k',
 'ut',
 'jÃŃcÃŃ',
 'Ġse',
 'Ġv',
 'ĠÄĮes',
 'ku',
 'Ġa',
 'Ġna',
 'ĠMor',
 'avÄĽ',
 '.']

## Train new Tokenizer

In [8]:
#25000 based on the SambaLingo paper
tokenizer = old_tokenizer.train_new_from_iterator(serve_texts(), 25000)






In [9]:
tokenizer.tokenize(example)

['Åĺe',
 'ÅĻi',
 'cha',
 'Ġje',
 'ĠkvÄĽt',
 'ina',
 'Ġvy',
 'sku',
 't',
 'jÃŃcÃŃ',
 'Ġse',
 'Ġv',
 'ĠÄĮesku',
 'Ġa',
 'Ġna',
 'ĠMoravÄĽ',
 '.']

## Tokenizer Merging

In [10]:
old_tokenizer.vocab_size, tokenizer.vocab_size

(128000, 25000)

In [14]:
old_vocab = set(old_tokenizer.get_vocab())
new_vocab = set(tokenizer.get_vocab())


diff_vocab = new_vocab - old_vocab
diff_vocab, len(diff_vocab)


({'Ġnasy',
  'Ġvstupenky',
  'ÅĻadu',
  'Ġvce',
  'teÄįnÃ½m',
  'Ġdezer',
  'ĠaÄį',
  'ĠpÅĻedstaven',
  'spod',
  'ĠudÄĽlala',
  'venÃŃ',
  'ĠpÄĽknÄĽ',
  'ĠdÃŃlny',
  'hovÃ©',
  'ĠDisku',
  'ĠtÄĽÅ¾kÃ¡',
  'ĠstÅĻihu',
  'Ġpum',
  'ĠsouÄįasnÄĽ',
  'ĠdobrÃ½m',
  '|Hmotnost',
  'Ġkoru',
  'Ġosoba',
  'Ġspustit',
  'ĠÄįernÃ½ch',
  'ĠsfÃ©',
  'manÅ¯v',
  'Ġrytmu',
  'Å¾Ã¡d',
  'fikovanÃ©',
  'ĠÃºplnÃ©',
  'VÃŃce',
  'ĠvÄĽnovat',
  'ĠprotÃ¡',
  'Ġkonzu',
  'Ġkoleg',
  'Ġmaleb',
  'Ġkonstrukce',
  'ĠslouÄįen',
  'ĠpraktickÃ½ch',
  'HlavnÃŃ',
  'ĠNedo',
  'ĠnÃ¡poj',
  'ntÃ¡l',
  'nejlepÅ¡ÃŃ',
  'Ġnejmlad',
  'tina',
  'ĠMys',
  'ĠnevÃ¡',
  'ĠprÃŃjmu',
  'uret',
  'ĠpodnikatelskÃ©',
  'ĠexistujÃŃ',
  'ĠmÅ¯Å¾ou',
  'ĠÅ¾ijÃŃcÃŃ',
  'ĠTereza',
  'Å¾uje',
  'ĠjÃŃdlo',
  'ramek',
  'ĠvÅ¯nÄĽ',
  'ĠnÃ¡vÅ¡tÄĽva',
  'ĠmodernÃŃch',
  'ĠdÅ¯sledku',
  'ovanÃ½m',
  'Å½i',
  'Ġpodpis',
  'hlo',
  'Ġpasu',
  'ĠzÃ¡ru',
  'ĠvÃ½sledkÅ¯',
  'Ġpotravin',
  'Äįany',
  'riginÃ¡lnÃŃ',
  'Ġjednomu',
  'ÅĪujÃŃcÃŃ',
  'Ġ

In [13]:
new_tokens = list(diff_vocab)
new_tokens

['Ġnasy',
 'Ġvstupenky',
 'ÅĻadu',
 'Ġvce',
 'teÄįnÃ½m',
 'Ġdezer',
 'ĠaÄį',
 'ĠpÅĻedstaven',
 'spod',
 'ĠudÄĽlala',
 'venÃŃ',
 'ĠpÄĽknÄĽ',
 'ĠdÃŃlny',
 'hovÃ©',
 'ĠDisku',
 'ĠtÄĽÅ¾kÃ¡',
 'ĠstÅĻihu',
 'Ġpum',
 'ĠsouÄįasnÄĽ',
 'ĠdobrÃ½m',
 '|Hmotnost',
 'Ġkoru',
 'Ġosoba',
 'Ġspustit',
 'ĠÄįernÃ½ch',
 'ĠsfÃ©',
 'manÅ¯v',
 'Ġrytmu',
 'Å¾Ã¡d',
 'fikovanÃ©',
 'ĠÃºplnÃ©',
 'VÃŃce',
 'ĠvÄĽnovat',
 'ĠprotÃ¡',
 'Ġkonzu',
 'Ġkoleg',
 'Ġmaleb',
 'Ġkonstrukce',
 'ĠslouÄįen',
 'ĠpraktickÃ½ch',
 'HlavnÃŃ',
 'ĠNedo',
 'ĠnÃ¡poj',
 'ntÃ¡l',
 'nejlepÅ¡ÃŃ',
 'Ġnejmlad',
 'tina',
 'ĠMys',
 'ĠnevÃ¡',
 'ĠprÃŃjmu',
 'uret',
 'ĠpodnikatelskÃ©',
 'ĠexistujÃŃ',
 'ĠmÅ¯Å¾ou',
 'ĠÅ¾ijÃŃcÃŃ',
 'ĠTereza',
 'Å¾uje',
 'ĠjÃŃdlo',
 'ramek',
 'ĠvÅ¯nÄĽ',
 'ĠnÃ¡vÅ¡tÄĽva',
 'ĠmodernÃŃch',
 'ĠdÅ¯sledku',
 'ovanÃ½m',
 'Å½i',
 'Ġpodpis',
 'hlo',
 'Ġpasu',
 'ĠzÃ¡ru',
 'ĠvÃ½sledkÅ¯',
 'Ġpotravin',
 'Äįany',
 'riginÃ¡lnÃŃ',
 'Ġjednomu',
 'ÅĪujÃŃcÃŃ',
 'ĠspadÃ¡',
 'ĠpouliÄįnÃŃ',
 'Ġsvita',
 'Ġletovice',
 'ĠÄįela',
 'chne',
 'nick

In [15]:
#add the difference between vocabularies (creating union of the two vocabularies)
num_added_toks = old_tokenizer.add_tokens(new_tokens)

In [24]:
num_added_toks

17017

In [23]:
print(len(old_tokenizer))

145273


In [25]:
old_tokenizer.save_pretrained("models/Llama-3.2-3B-Instruct-cs_expanded")

('models/Llama-3.2-3B-Instruct-cs_expanded/tokenizer_config.json',
 'models/Llama-3.2-3B-Instruct-cs_expanded/special_tokens_map.json',
 'models/Llama-3.2-3B-Instruct-cs_expanded/tokenizer.json')

## Resize the Token embeddings

In [None]:
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, PreTrainedModel
from typing import Union, Optional
import torch
from torch import nn

def resize_model_embeddings(model: PreTrainedModel,
                            old_tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], 
                            new_tokenizer:Union[PreTrainedTokenizer, PreTrainedTokenizerFast], 
                            type: Optional[str] = "mean_resizing"):
    assert type in ["mean_resizing", "subword_resizing", None]

    if type == "mean_resizing":
        #https://nlp.stanford.edu/~johnhew/vocab-expansion.html
        model.resize_token_embeddings(len(new_tokenizer), mean_resizing=True)
    elif type == "subword_resizing":
        #obtain original embeddings
        original_input_embeddings = model.get_input_embeddings()

        #get new tokens, ordered by ids
        new_tokens = list(set(new_tokenizer.get_vocab().keys()) - set(old_tokenizer.get_vocab().keys()))
        new_vocab = new_tokenizer.get_vocab()
        new_tokens = sorted(new_tokens, key=lambda x: new_vocab[x])
        
        #prepare new embeddings for extension
        added_embeddings = []
        for token in new_tokens:
            #get input ids in the original tokenizer
            input_ids = old_tokenizer.encode(token, add_special_tokens=False, return_tensors="pt")[0]

            #select corresponding embeddings from the original ones
            embeddings = original_input_embeddings(input_ids)
            #compute the mean
            mean_embedding = embeddings.mean(dim=0)

            #save the new embedding
            added_embeddings.append(mean_embedding)

        added_embeddings = torch.stack(added_embeddings).to(original_input_embeddings.weight.dtype)
        
        #prepare new embedding layer
        old_num_tokens, old_embedding_dim = original_input_embeddings.weight.shape

        new_num_tokens = old_num_tokens + added_embeddings.shape[0]
        new_embeddings = nn.Embedding(
                    new_num_tokens,
                    old_embedding_dim,
                    device=original_input_embeddings.weight.device,
                    dtype=original_input_embeddings.weight.dtype,
                )
        
        #copy the original embeddings
        new_embeddings.weight.data[:old_num_tokens] = original_input_embeddings.weight
        #initialize the new embeddings
        new_embeddings.weight.data[old_num_tokens:] = added_embeddings
        #set gradient requirement
        new_embeddings.requires_grad_(original_input_embeddings.weight.requires_grad)

        #set the new embeddings to the model
        model.set_input_embeddings(new_embeddings)

        #check if the model has tied weights in ouput layer
        if not model.config.get_text_config(decoder=True).tie_word_embeddings:
            raise ValueError("The model does not have tied weights in the output layer. This is currently not supported by the subword resizing.")
        
        vocab_size = model.get_input_embeddings().weight.shape[0]

        model.config.get_text_config().vocab_size = vocab_size
        model.vocab_size = vocab_size

        model.tie_weights()
        
    else:
        model.resize_token_embeddings(len(new_tokenizer), mean_resizing=False)

    

In [11]:
from transformers import AutoTokenizer, AutoModelForCausalLM

old_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
new_tokenizer = AutoTokenizer.from_pretrained("models/Llama-3.2-3B-Instruct-cs_expanded")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")

resize_model_embeddings(model, old_tokenizer, new_tokenizer, type="subword_resizing")
model

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

['Ġnasy', 'Ġvstupenky', 'ÅĻadu', 'Ġvce', 'teÄįnÃ½m', 'Ġdezer', 'ĠaÄį', 'ĠpÅĻedstaven', 'spod', 'ĠudÄĽlala', 'venÃŃ', 'ĠpÄĽknÄĽ', 'ĠdÃŃlny', 'hovÃ©', 'ĠDisku', 'ĠtÄĽÅ¾kÃ¡', 'ĠstÅĻihu', 'Ġpum', 'ĠsouÄįasnÄĽ', 'ĠdobrÃ½m', '|Hmotnost', 'Ġkoru', 'Ġosoba', 'Ġspustit', 'ĠÄįernÃ½ch', 'ĠsfÃ©', 'manÅ¯v', 'Ġrytmu', 'Å¾Ã¡d', 'fikovanÃ©', 'ĠÃºplnÃ©', 'VÃŃce', 'ĠvÄĽnovat', 'ĠprotÃ¡', 'Ġkonzu', 'Ġkoleg', 'Ġmaleb', 'Ġkonstrukce', 'ĠslouÄįen', 'ĠpraktickÃ½ch', 'HlavnÃŃ', 'ĠNedo', 'ĠnÃ¡poj', 'ntÃ¡l', 'nejlepÅ¡ÃŃ', 'Ġnejmlad', 'tina', 'ĠMys', 'ĠnevÃ¡', 'ĠprÃŃjmu', 'uret', 'ĠpodnikatelskÃ©', 'ĠexistujÃŃ', 'ĠmÅ¯Å¾ou', 'ĠÅ¾ijÃŃcÃŃ', 'ĠTereza', 'Å¾uje', 'ĠjÃŃdlo', 'ramek', 'ĠvÅ¯nÄĽ', 'ĠnÃ¡vÅ¡tÄĽva', 'ĠmodernÃŃch', 'ĠdÅ¯sledku', 'ovanÃ½m', 'Å½i', 'Ġpodpis', 'hlo', 'Ġpasu', 'ĠzÃ¡ru', 'ĠvÃ½sledkÅ¯', 'Ġpotravin', 'Äįany', 'riginÃ¡lnÃŃ', 'Ġjednomu', 'ÅĪujÃŃcÃŃ', 'ĠspadÃ¡', 'ĠpouliÄįnÃŃ', 'Ġsvita', 'Ġletovice', 'ĠÄįela', 'chne', 'nickou', 'Ġrozli', 'ĠodhalÃŃ', 'ĠDÃ¡n', 'Ġmamin', 'ĠobhÃ¡', 'ĠpÅĻejÃŃ', 'anis', 'ĠzÃ¡

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(145273, 3072)
    (layers): ModuleList(
      (0-27): 28 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=3072, out_features=3072, bias=False)
          (k_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (v_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (o_proj): Linear(in_features=3072, out_features=3072, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=3072, out_features=8192, bias=False)
          (up_proj): Linear(in_features=3072, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=3072, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((3072,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((3072,), eps=1e-05)
      )
    )
    (norm