## Using SentencePiece to Train a Tokenizer on a mini-batch of data from enwikisource

In [83]:
# Import dependencies
import sentencepiece as sp
import pandas as pd
import torch
import os
from collections import Counter

In [5]:
# Prefix for the trained tokenizer artifacts (".model" and ".vocab" are appended to it)
model_prefix = '../models/sptokenizer_16384'
# Plain-text corpus used both for training the tokenizer and for counting tokens
input_file = '../../data/enwiki_20240320_minibatch.txt'

In [49]:
# Train model
# Collect the trainer options in one place, then hand them to SentencePiece.
trainer_options = dict(
    input=input_file,
    model_prefix=model_prefix,
    vocab_size=16384,
    character_coverage=1.0,
    model_type='bpe',
)
sp.SentencePieceTrainer.train(**trainer_options)

print(f'Model and vocabulary have been generated: {model_prefix}.model and {model_prefix}.vocab')

Model and vocabulary have been generated: ../models/sptokenizer_16384.model and ../models/sptokenizer_16384.vocab


## Compute distribution of tokens across our training data
---
We need a count of how often each token appears in the training data so that we can build a sampling distribution for effective negative sampling.

In [50]:
# Load the tokenizer from the .model file on disk
model_path = "../models/sptokenizer_16384.model"
model = sp.SentencePieceProcessor()
model.load(model_path)

# Smoke test: encode a sentence, then decode the result back into text
hello_ids = model.encode("Hello world!")
model.decode(hello_ids)

'Hello world!'

In [55]:
# Now - we need to iterate through our data and get a count for the number of times each token is seen
def count_tokens(data_file, spm_model):
    """
    Tally how often each token is produced when encoding a text file line by line.

    Each line is stripped of surrounding whitespace; lines that are empty after
    stripping are skipped. Whatever ``spm_model.encode(line)`` returns is fed to
    ``Counter.update``.

    :param data_file: Path to a UTF-8 text file
    :param spm_model: Object exposing an ``encode(str)`` method (in this notebook,
                      the loaded SentencePiece processor)
    :return: ``collections.Counter`` of the encoded tokens
    """
    tally = Counter()
    with open(data_file, 'r', encoding='utf-8') as handle:
        stripped_lines = (raw.strip() for raw in handle)
        # filter(None, ...) drops the empty strings left by blank lines
        for text in filter(None, stripped_lines):
            tally.update(spm_model.encode(text))
    return tally

In [82]:
# Tally token IDs over the full training file
token_counts = count_tokens(input_file, spm_model=model)

# Rank (id, count) pairs from most to least frequent and print the first 50
sorted_counts = list(token_counts.items())
sorted_counts.sort(key=lambda item: item[1], reverse=True)
print(f"ID     Count    Word\n{'-'*24}")
for rank in range(50):
    idx, count = sorted_counts[rank]
    print(f"{idx:>5}: {count:>6} - \"{model.decode(idx)}\"")

ID     Count    Word
------------------------
16210: 215732 - ","
    6: 205933 - "the"
   16: 129767 - "of"
   21: 110446 - "and"
16213: 109705 - "."
   32:  80859 - "to"
   33:  55222 - "in"
    5:  47629 - "a"
   58:  40097 - "that"
   74:  28892 - "is"
   48:  28110 - "I"
   44:  27899 - "be"
   65:  27017 - "for"
16220:  25077 - ";"
   83:  22347 - "it"
   87:  20951 - "with"
   93:  20755 - "not"
   89:  19772 - "as"
16224:  17964 - ":"
  111:  17868 - "by"
  110:  17845 - "his"
   63:  17648 - "he"
   98:  17068 - "or"
  119:  16701 - "shall"
  121:  15817 - "was"
  128:  15696 - "which"
16197:  15200 - "s"
  133:  15194 - "have"
  130:  14607 - "all"
  142:  14390 - "they"
  153:  13490 - "And"
  147:  13462 - "are"
  101:  13419 - "The"
   88:  13365 - "on"
  157:  13247 - "from"
  159:  13205 - "this"
  162:  12755 - "their"
16232:  12554 - "-"
  154:  12234 - "them"
16198:  12036 - "h"
  170:  11516 - "will"
   91:  10939 - "we"
  173:  10792 - "my"
16219:  10286 - "'"
  156

In [100]:
# Now - convert our counter dict to a distribution
# The distribution needs exactly one slot per token ID, i.e. the vocab_size the
# tokenizer was trained with (16384). The loop bound used to read 16834 (digits
# transposed), which silently appended 450 zero-probability slots: Counter returns
# 0 for missing keys, so no error was raised.
vocab_size = 16384
# Token IDs seen this many times or fewer get zero probability mass
min_count = 10

clipped_freq = [
    token_counts[token_id] if token_counts[token_id] > min_count else 0
    for token_id in range(vocab_size)
]

# Normalise the clipped counts so they sum to 1, then persist them
frequencies_tensor = torch.tensor(clipped_freq, dtype=torch.float)
frequencies_tensor = frequencies_tensor / frequencies_tensor.sum()
torch.save(frequencies_tensor, '../token_distribution/frequencies_16384.pt')

## Inspect Vocabulary
---

In [101]:
def load_and_print_vocab_samples(vocab_file, start_index=0, num_samples=10):
    """
    Print a window of entries from a SentencePiece .vocab file.

    The token of each entry is the text before the first tab on its line. Entries
    are printed as ``Index <i>: <token>``, one per line; nothing is returned.

    :param vocab_file: Path to the SentencePiece .vocab file (read as UTF-8)
    :param start_index: Index of the first entry to print
    :param num_samples: Maximum number of entries to print
    """
    tokens = []
    with open(vocab_file, 'r', encoding='utf-8') as handle:
        for entry in handle:
            # Keep only the token column (everything before the first tab)
            token, _, _ = entry.partition('\t')
            tokens.append(token)

    # Clamp the upper end of the window to the vocabulary size
    stop = min(start_index + num_samples, len(tokens))

    for position in range(start_index, stop):
        print(f'Index {position}: {tokens[position]}')

In [102]:
# Look at a 256-entry window of the vocabulary, starting at index 14000
vocab_file = '../models/sptokenizer_16384.vocab'
load_and_print_vocab_samples(
    vocab_file=vocab_file,
    start_index=14000,
    num_samples=256,
)

Index 14000: ▁Macondo
Index 14001: ▁Persian
Index 14002: ▁Shortly
Index 14003: ▁acquies
Index 14004: ▁biggest
Index 14005: ▁bullets
Index 14006: ▁carpent
Index 14007: ▁ceasing
Index 14008: ▁conject
Index 14009: ▁consign
Index 14010: ▁conspic
Index 14011: ▁earthly
Index 14012: ▁highway
Index 14013: ▁induced
Index 14014: ▁insight
Index 14015: ▁lecture
Index 14016: ▁nourish
Index 14017: ▁nowhere
Index 14018: ▁parched
Index 14019: ▁plaster
Index 14020: ▁relapse
Index 14021: ▁shorter
Index 14022: ▁slender
Index 14023: ▁sounded
Index 14024: ▁spotted
Index 14025: ▁stature
Index 14026: ▁swiftly
Index 14027: ▁uncover
Index 14028: etermined
Index 14029: ▁Building
Index 14030: ▁Chambers
Index 14031: ▁Paradise
Index 14032: ▁Persians
Index 14033: ▁Zubaydah
Index 14034: ▁accursed
Index 14035: ▁apostles
Index 14036: ▁brighter
Index 14037: ▁coercion
Index 14038: ▁counting
Index 14039: ▁depended
Index 14040: ▁embodied
Index 14041: ▁hallowed
Index 14042: ▁heartily
Index 14043: ▁hijacker
Index 14044: ▁in