<a href="https://colab.research.google.com/github/muchad/VocabPrune/blob/main/VocabPrune.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# VocabPrune

This notebook implements the deterministic language-aware frequency-based vocabulary pruning strategy for the mDeBERTa-v3-base model.

The workflow:
1.  **Corpus Selection:** Using two monolingual web corpora (Indonesian and English, from the Leipzig Corpora Collection) to estimate token utility.
2.  **Frequency Analysis:** Using the original tokenizer to count subword occurrences in each corpus.
3.  **Language-Aware Pruning:** Constructing a hybrid vocabulary with a fixed allocation ratio while ensuring disjoint selection to maximize coverage.
4.  **Reconstruction:** Re-indexing the embedding matrix and reconstructing the tokenizer to be a drop-in replacement for the original model.


GitHub: https://github.com/muchad/VocabPrune/

In [None]:
# @title Configuration & Corpus Selection
# Every tunable value of the notebook lives in this cell; later cells only read
# these names.

# Model Checkpoint
# https://huggingface.co/microsoft/mdeberta-v3-base
MODEL_CKPT = "microsoft/mdeberta-v3-base"

# Pruning Parameters
TARGET_VOCAB_SIZE = 30000   # Target final vocabulary size
INDO_RATIO = 0.7            # Allocation for Indonesian tokens (70%)
ENG_RATIO = 0.3             # Allocation for English tokens (30%)

# Dataset URLs (Leipzig Corpora)
INDO_CORPUS_URL = "https://downloads.wortschatz-leipzig.de/corpora/ind-id_web-public_2017_1M.tar.gz"
ENG_CORPUS_URL = "https://downloads.wortschatz-leipzig.de/corpora/eng-com_web-public_2018_1M.tar.gz"

# Output Directory (named after the target size in thousands, e.g. "...-30k")
OUTPUT_DIR = f"mdeberta-hybrid-{TARGET_VOCAB_SIZE // 1000}k"

print("Configuration Set:")
print(f"Model: {MODEL_CKPT}")
print(f"Target Vocab: {TARGET_VOCAB_SIZE}")
# Format as percentages directly: `0.3 * 100` evaluates to 30.000000000000004
# in binary floating point, which is what the raw multiplication would print.
print(f"Ratio: {INDO_RATIO:.1%} ID / {ENG_RATIO:.1%} EN")

# Non-fatal sanity check: the pruning cell still runs with other ratios, but
# the printed allocation would then not describe a split of the whole budget.
if abs(INDO_RATIO + ENG_RATIO - 1.0) > 1e-9:
    print(f"Warning: INDO_RATIO + ENG_RATIO = {INDO_RATIO + ENG_RATIO:.3f} (expected 1.0)")

In [None]:
# @title Install Libraries & Download Data
import os
import glob

# 1. Install Dependencies
!pip install -q transformers sentencepiece tqdm

# 2. Download & Extract Dataset
def download_corpus(url, folder_name):
    """Download a ``.tar.gz`` corpus archive and extract it into ``folder_name``.

    The download is skipped when ``folder_name`` already exists and contains
    at least one entry. An empty folder (e.g. left behind by an earlier failed
    run) is treated as "not downloaded yet".

    Args:
        url: Address of the archive; its last path segment is used as the
            local archive file name in the current working directory.
        folder_name: Directory the archive is extracted into.

    Raises:
        subprocess.CalledProcessError: If ``wget`` or ``tar`` exits non-zero.
        FileNotFoundError: If ``wget`` or ``tar`` is not installed.
    """
    import shutil
    import subprocess

    if os.path.isdir(folder_name) and os.listdir(folder_name):
        print(f"Folder {folder_name} already exists.")
        return

    print(f"Downloading corpus to {folder_name}...")
    filename = url.split("/")[-1]
    try:
        # Argument lists avoid shell interpolation of the URL; `check=True`
        # turns a failed download/extraction into an exception instead of
        # silently continuing. `-O` overwrites a stale partial archive.
        subprocess.run(["wget", "-q", "-O", filename, url], check=True)
        os.makedirs(folder_name, exist_ok=True)
        subprocess.run(["tar", "-xzf", filename, "-C", folder_name], check=True)
    except BaseException:
        # Remove a partially populated folder so the next run retries instead
        # of reporting "already exists"; then propagate the original error.
        shutil.rmtree(folder_name, ignore_errors=True)
        raise
    print("Done.")

download_corpus(INDO_CORPUS_URL, "ind_data")
download_corpus(ENG_CORPUS_URL, "eng_data")

# 3. Locate Text Files
def _find_sentences_file(folder_name):
    """Return the path of a ``*sentences.txt`` file anywhere under ``folder_name``.

    ``recursive=True`` makes ``**`` match any directory depth (including zero),
    so the lookup does not depend on the archive's internal layout. Matches
    are sorted so the choice is deterministic if several files exist.

    Raises:
        FileNotFoundError: If no matching file is found.
    """
    pattern = os.path.join(folder_name, "**", "*sentences.txt")
    matches = sorted(glob.glob(pattern, recursive=True))
    if not matches:
        raise FileNotFoundError(
            f"No '*sentences.txt' file found under '{folder_name}'. "
            "The download or extraction probably failed; delete the folder and re-run this cell."
        )
    return matches[0]

ind_file = _find_sentences_file("ind_data")
eng_file = _find_sentences_file("eng_data")

print(f"\nIndonesian Corpus: {ind_file}")
print(f"English Corpus:    {eng_file}")

In [None]:
# @title Subword Tokenization and Frequency Analysis
import torch
from transformers import DebertaV2TokenizerFast
from collections import Counter
from tqdm.auto import tqdm

# Load the original (unpruned) tokenizer for MODEL_CKPT. The `tokenizer` object
# is reused by later cells for counting, special-token lookup and saving.
print(f"Loading tokenizer from {MODEL_CKPT}...")
tokenizer = DebertaV2TokenizerFast.from_pretrained(MODEL_CKPT)
# Size reported by the tokenizer before pruning; only displayed here.
original_vocab_size = tokenizer.vocab_size
print(f"Original Vocab Size: {original_vocab_size}")

def count_tokens(filename, tokenizer, limit_lines=None, batch_size=1000):
    """Count how often each subword token id occurs in a corpus file.

    Each line is either ``<id>\\t<sentence>`` or a bare sentence. Everything
    after the first tab is taken as the sentence; blank sentences are skipped.

    Args:
        filename: Path to a UTF-8 text file, one sentence per line.
        tokenizer: Callable invoked as ``tokenizer(batch, add_special_tokens=False)``
            whose result exposes ``input_ids`` (one id list per sentence).
        limit_lines: If truthy, only the first ``limit_lines`` lines are read;
            ``None`` or ``0`` reads the whole file.
        batch_size: Number of lines handed to the tokenizer per call.

    Returns:
        collections.Counter mapping token id -> number of occurrences.
    """
    import os
    from itertools import islice

    def _sentence(raw_line):
        # Split on the FIRST tab only, and before stripping: stripping first
        # would delete a trailing tab and make the `[1]` lookup fail, and an
        # unbounded split would truncate sentences that contain a tab.
        text = raw_line.rstrip("\r\n")
        if "\t" in text:
            text = text.split("\t", 1)[1]
        return text.strip()

    with open(filename, "r", encoding="utf-8") as f:
        # islice stops reading at the limit instead of loading the full file.
        lines = list(islice(f, limit_lines or None))

    counter = Counter()
    print(f"Processing {os.path.basename(filename)}...")
    for i in tqdm(range(0, len(lines), batch_size)):
        batch = [s for s in map(_sentence, lines[i:i + batch_size]) if s]
        if not batch:
            continue
        # Tokenize (we only need input_ids)
        encodings = tokenizer(batch, add_special_tokens=False)
        for ids in encodings.input_ids:
            counter.update(ids)
    return counter

# Execute Counting
# Count token-id frequencies over each corpus file with the original tokenizer
# (no line limit, default batch size). The two counters drive the pruning cell.
cnt_ind = count_tokens(ind_file, tokenizer)
cnt_eng = count_tokens(eng_file, tokenizer)

print("Frequency analysis complete.")

In [None]:
# @title Language-Aware Vocabulary Pruning
# --- 1. Budget -------------------------------------------------------------
# Special tokens are always kept; the rest of the budget is split by ratio.
special_tokens = tokenizer.all_special_ids
num_special = len(special_tokens)
available_slots = TARGET_VOCAB_SIZE - num_special

target_indo, target_eng = (
    int(available_slots * ratio) for ratio in (INDO_RATIO, ENG_RATIO)
)

print(f"Allocation -> Special: {num_special}, Indo: {target_indo}, Eng: {target_eng}")

# --- 2. Selection ----------------------------------------------------------
# Indonesian goes first: its `target_indo` most frequent token ids join the
# special tokens in the keep-set.
most_common_ind = cnt_ind.most_common(target_indo)
ind_ids = [token_id for token_id, _freq in most_common_ind]
keep_ids = set(special_tokens) | set(ind_ids)

# English fills whatever is left of the total budget, walking its tokens from
# most to least frequent and skipping ids that are already kept.
current_count = len(keep_ids)
remaining_slots = TARGET_VOCAB_SIZE - current_count

most_common_eng = cnt_eng.most_common()
added_eng = 0
for t_id, _ in most_common_eng:
    if added_eng >= remaining_slots:
        break
    if t_id in keep_ids:
        continue
    keep_ids.add(t_id)
    added_eng += 1

# --- 3. Id mappings --------------------------------------------------------
# New ids are the ranks of the kept old ids in ascending order.
sorted_keep_ids = sorted(keep_ids)
new_to_old_map = dict(enumerate(sorted_keep_ids))
old_to_new_map = {old_id: new_id for new_id, old_id in new_to_old_map.items()}

print(f"\nFinal Vocab Size: {len(sorted_keep_ids)}")

In [None]:
# @title Embedding and Tokenizer Reconstruction
import json
import shutil
from transformers import AutoModel, AutoConfig

print("1. Pruning Embedding Matrix (Weight Transfer)")
# Load original model
model = AutoModel.from_pretrained(MODEL_CKPT)
old_embeddings = model.embeddings.word_embeddings.weight.data
new_vocab_size = len(sorted_keep_ids)
embedding_dim = old_embeddings.shape[1]

# Initialize new embedding matrix
new_embeddings = torch.zeros((new_vocab_size, embedding_dim))

# Copy weights based on mapping
for new_id, old_id in tqdm(new_to_old_map.items(), desc="Copying Weights"):
    new_embeddings[new_id] = old_embeddings[old_id]

# Assign to model
model.embeddings.word_embeddings.weight = torch.nn.Parameter(new_embeddings)
model.config.vocab_size = new_vocab_size
model.embeddings.word_embeddings.weight.data = new_embeddings

print("\n2. Reconstructing Tokenizer JSON")
if os.path.exists("temp_tokenizer"): shutil.rmtree("temp_tokenizer")
tokenizer.save_pretrained("temp_tokenizer")

with open("temp_tokenizer/tokenizer.json", "r", encoding="utf-8") as f:
    tokenizer_data = json.load(f)

old_vocab = tokenizer_data["model"]["vocab"]
new_vocab_list = []
is_list_vocab = isinstance(old_vocab, list)

# Rebuild vocab list preserving SentencePiece structure
for new_id in range(new_vocab_size):
    old_id = new_to_old_map[new_id]
    if is_list_vocab:
        try:
            vocab_entry = old_vocab[old_id]
            new_vocab_list.append(vocab_entry)
        except IndexError:
            # Fallback for special tokens if index mismatch occurs
            token_str = tokenizer.convert_ids_to_tokens(old_id)
            new_vocab_list.append([token_str, 0.0])

if is_list_vocab:
    tokenizer_data["model"]["vocab"] = new_vocab_list

# Reset added_tokens to prevent index conflicts
tokenizer_data["added_tokens"] = []

# Save
if not os.path.exists(OUTPUT_DIR): os.makedirs(OUTPUT_DIR)
model.save_pretrained(OUTPUT_DIR)

with open(f"{OUTPUT_DIR}/tokenizer.json", "w", encoding="utf-8") as f:
    json.dump(tokenizer_data, f, ensure_ascii=False)

for f_name in ["tokenizer_config.json", "special_tokens_map.json"]:
    try: shutil.copy(f"temp_tokenizer/{f_name}", f"{OUTPUT_DIR}/{f_name}")
    except: pass

print(f"Pruned model saved to: {OUTPUT_DIR}/")
!ls -lh {OUTPUT_DIR}

In [None]:
# @title Verify & Compare
from transformers import AutoTokenizer

print("Loading Pruned Model for Verification...")
try:
    new_tokenizer = AutoTokenizer.from_pretrained(OUTPUT_DIR)
    new_model = AutoModel.from_pretrained(OUTPUT_DIR)
except Exception as e:
    print(f"Error loading model: {e}")
    # Re-raise: continuing would only fail later with a NameError on
    # `new_tokenizer`, hiding the real cause.
    raise
print("Model loaded successfully.")

# Comparison Test
text = "Pemerintah sedang mengimplementasikan framework Artificial Intelligence."

print("\n--- Tokenization Comparison ---")
# Original
orig_tok = DebertaV2TokenizerFast.from_pretrained(MODEL_CKPT)
orig_tokens = orig_tok.tokenize(text)
print(f"[Original] Tokens: {len(orig_tokens)} | {orig_tokens}")

# Pruned
pruned_tokens = new_tokenizer.tokenize(text)
print(f"[Pruned]   Tokens: {len(pruned_tokens)} | {pruned_tokens}")

diff = len(pruned_tokens) - len(orig_tokens)
print(f"\nDifference in length: {diff} tokens")

# Every id the pruned tokenizer emits must index a row of the pruned embedding
# matrix; otherwise the saved model would fail on its first forward pass.
pruned_ids = new_tokenizer(text)["input_ids"]
assert max(pruned_ids) < new_model.config.vocab_size, (
    f"Token id {max(pruned_ids)} is out of range for vocab size {new_model.config.vocab_size}"
)