# Embedding Trainer
---
Main script to train our various embedding models. What we'll do is import our 'Embedding Model' objects that we'll define in other files - and this code will be the training loop and will save the results to various folders.

## Create our Dataloader and Tokenizer
---
We already defined these and trained the tokenizer in the other files. So, let's go ahead and instantiate these as a first step

### Import Dependencies

In [1]:
# To allow for easy access of other packages in this directory, let's first nav to the project root
import sys
import os
project_root = os.path.dirname(os.getcwd())
sys.path.append(project_root)

In [2]:
from dataloader.dataloader import MyDataLoader
from tokenizer.tokenizer import MyTokenizer

### Instantiate our Dataloader

In [3]:
dl = MyDataLoader(promptuser=False, batch_size=1, shuffle=True)    # By setting promptuser=False, we just use the 'enwiki_articles_20240320_mini' dataset (50MB)
train_dataloader = dl.get_train_dataloader()
test_dataloader = dl.get_test_dataloader()

for batch in train_dataloader:
    sample_data = batch[0][0]
    break

print(f"Number of chars: {len(sample_data)}")
print(f"{'='*60}")
print(sample_data[:500])

Number of chars: 44986
Mr. President and fellow citizens of New York: -

The facts with which I shall deal this evening are mainly old and familiar; nor is there anything new in the general use I shall make of them. If there shall be any novelty, it will be in the mode of presenting the facts, and the inferences and observations following that presentation.

In his speech last autumn, at Columbus, Ohio, as reported in "The New-York Times," Senator Douglas said:

"Our fathers, when they framed the Government under whic


In [4]:
dl.print_samples(num_samples=3)

Sample #1
Number of chars: 10230
<h2> Case document</h2>
1.	These proceedings are now cited as Re Kevin : Validity of Marriage of Transsexual (2001) FamCA 1074 and (2001) FLC 93-087 ("Re Kevin"). Justice Chisholm's original decision, granting a Declaration of Validity of Marriage in

Sample #2
Number of chars: 8842
, , , ,  distinguished guests and my fellow citizens:

The peaceful transfer of authority is rare in history, yet common in our country. With a simple oath, we affirm old traditions and make new beginnings.

As I begin, I thank President Clinton for 

Sample #3
Number of chars: 15567
They were not long in reaching the barracks, for the officer who 
commanded the party was desirous to avoid rousing the people by the 
display of military force in the streets, and was humanely anxious 
to give as little opportunity as possible for a

Test Dataset Samples:
Number of chars: 1548
Shrines. Come along. It's rather picturesque. A variant on Velasquez's ''Les Lanzas.'''

Reluctantly J

### Instantiate our Tokenizer

In [5]:
tokenizer = MyTokenizer()
tk_vocab_size = tokenizer.get_vocab_size()

# Ensure our tokenizer is running properly
chars_to_print = 200
print(tokenizer.encode_as_ids(sample_data[:chars_to_print]))
print(f"{'-'*60}")
print(tokenizer.encode_as_pieces(sample_data[:chars_to_print]))
print(f"{'-'*60}")
print(tokenizer.decode(tokenizer.encode_as_ids(sample_data[:chars_to_print])))

[136, 186, 200, 137, 131, 66, 47, 21, 23, 178, 46, 84, 22, 17, 183, 235, 20, 184, 16, 177, 226, 178, 194, 177, 248, 26, 201, 211, 177, 219, 101, 23, 180, 64, 184, 87, 128, 48, 119, 116, 34, 159, 49, 199, 20, 36, 147, 24, 115, 68, 8, 73, 21, 23, 90, 100, 183, 37, 207, 41, 26, 74, 6, 9, 96, 195, 117, 36, 172, 194, 33, 6, 60, 20, 13, 34, 54, 56, 48, 119, 24, 132, 178, 16, 154, 200, 48, 191, 6, 186]
------------------------------------------------------------
['▁M', 'r', '.', '▁P', 'res', 'id', 'ent', '▁and', '▁f', 'e', 'll', 'ow', '▁c', 'it', 'i', 'z', 'en', 's', '▁of', '▁', 'N', 'e', 'w', '▁', 'Y', 'or', 'k', ':', '▁', '-', '▁The', '▁f', 'a', 'ct', 's', '▁with', '▁which', '▁I', '▁shall', '▁de', 'al', '▁this', '▁e', 'v', 'en', 'ing', '▁are', '▁m', 'ain', 'ly', '▁o', 'ld', '▁and', '▁f', 'am', 'il', 'i', 'ar', ';', '▁n', 'or', '▁is', '▁the', 're', '▁an', 'y', 'th', 'ing', '▁ne', 'w', '▁in', '▁the', '▁g', 'en', 'er', 'al', '▁u', 'se', '▁I', '▁shall', '▁m', 'ak', 'e', '▁of', '▁them', '.', '▁I

## Define our Helper Functions

In [6]:
import torch
import numpy as np

In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


In [8]:
def _yield_CBOW_batch(text, batch_size, n_window, util_rate, vocab_size=tk_vocab_size, device=device, tokenizer=tokenizer):
    """
        Generator function that takes in text and returns a number of batches for each dataset based on the utilization rate specified. To be used in a CBOW model.

        Inputs:
            text:       (string) The text provided by the dataloader
            batch_size: (int) Number of samples to return in each batch
            n_window:   (int) Size of our context window for the CBOW model. I.e., if n_window=4, then we will use the left 4 words and right 4 words to predict our target word
            util_rate:  (float) Value from (0, 1] that specifies the % of possible batches that are generated before moving to next sample
            vocab_size: (int) size of our tokenizer vocabulary
            device:     Pytorch device (e.g., cuda / cpu)
            tokenizer:  Our defined tokenizer (above). encode_as_ids(text) returns a 1-D python list of tokens

        Yields (generator function) batches of data in the form of GPU-mounted pytorch tensors until util_rate is tripped.
    """
    tokens = torch.tensor(tokenizer.encode_as_ids(text), device=device)
    len_tokens = len(tokens)
    num_possible_pairs = len_tokens - (2 * n_window)
    num_batches = int((num_possible_pairs * util_rate) // batch_size)
    
    center_indices = torch.arange(n_window, len_tokens - n_window, device=device)
    center_indices = center_indices[torch.randperm(center_indices.size(0))][:num_batches*batch_size]
    
    for i in range(num_batches):
        batch_center_indices = center_indices[i*batch_size:(i+1)*batch_size]
        
        # Initialize the context and target tensors
        context_tensor = torch.zeros(batch_size, vocab_size, device=device)
        target_tensor = torch.zeros(batch_size, vocab_size, device=device)
        
        for idx, center_idx in enumerate(batch_center_indices):
            # For each center word index, create a context window
            context_indices = torch.cat((tokens[center_idx - n_window:center_idx], tokens[center_idx + 1:center_idx + 1 + n_window]))
            
            # Update the context_tensor for all context words (summing up one-hot vectors)
            for context_idx in context_indices:
                context_tensor[idx, context_idx] += 1
            normalized_context_tensor = context_tensor / (n_window*2)   # we want the sum of each context tensor to be 1 (so becomes average of the associated word vecs)
            
            # Update the target_tensor
            target_tensor[idx, tokens[center_idx]] = 1
        
        yield normalized_context_tensor, target_tensor

In [9]:
# Hyperparams
batch_size=8
n_window=4
util_rate=0.4

for batch, (context, target) in enumerate(_yield_CBOW_batch(sample_data, batch_size, n_window, util_rate, device=device, tokenizer=tokenizer)):
    print(f"Batch {batch+1}")
    print(f"{'-'*60}")
    print(context.sum())
    print(target.sum())
    break

Batch 1
------------------------------------------------------------
tensor(8.)
tensor(8.)


In [10]:
# Next - let's keep track of our loss data
def _estimate_loss(model, n_iters, device=device):
    """
        Function to estimate our loss (train and test) that we can call

        Inputs:
            model:   Pytorch sequential model
            n_iters: (int) Specify number of iterations to compute loss over
            device:  Pytorch device (cuda / cpu)
    """
    pass

## Import and Train our Model

In [11]:
# Import Pytorch Dependencies
import torch.nn as nn
from torch.nn import functional as F

In [12]:
# Implementation of a simple LBL model as laid out in https://www.cs.toronto.edu/~amnih/papers/hlbl_final.pdf (Introduces Hierarchal Softmax)
class LogBilinearModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, device):
        super().__init__()
        self.vocab_size = vocab_size
        self.device = device
        self.net = nn.Sequential(
            nn.Linear(vocab_size, embed_dim),
            nn.Linear(embed_dim, vocab_size)
        )
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

    def forward(self, x, targets=None):
        """
            Single forward pass, takes in matrix 'x' [B x vocab_size] where 'x' is a multi-hot encoding relating to the tokens in the context window.
            Targets (optional for loss calculation) is a [B x V] vector
        """
        logits = self.net(x)  # [B x Vocab_size(V)] @ [V x V] -> [B x V]
        if targets is None:
            loss = None
        else:
            B, V = logits.shape
            logits_flat = logits.view(B*V)
            targets = targets.view(B*V)
            loss = F.cross_entropy(logits_flat, targets)
        return logits, loss

In [13]:
# Hyperparams
batch_size=8
n_window=4
util_rate=0.5
learning_rate = 3e-4
embed_dim = 128

# Define a simple training loop
def train_model(modelclass, train_dl, test_dl, tokenizer, device, saveweights=True):
    """
        Main training loop for our model that we specify
    """
    model = modelclass(tk_vocab_size, embed_dim, device)
    m = model.to(device)
    print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    loss_hist = []
    temp_loss_hist = []
    iter = 0

    for sample in train_dl:
        for (context, targets) in _yield_CBOW_batch(sample[0][0], batch_size, n_window, util_rate, tk_vocab_size, device=device, tokenizer=tokenizer):
            iter += 1
            logits, loss = model(context, targets)
            temp_loss_hist.append(loss)
            
            if iter % 1000 == 0:
                ave_loss = torch.tensor(temp_loss_hist).mean()
                loss_hist.append(ave_loss)
                temp_loss_hist = []
                if iter % 10000 == 0:
                    print(f"Iter: {iter} - Loss: {ave_loss:.3f}")

            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()
            
        if iter>1e7:
            return loss_hist

In [14]:
# Running our training loop
loss = train_model(LogBilinearModel, train_dataloader, test_dataloader, tokenizer, device)

0.06592 M parameters
Iter: 10000 - Loss: 56.110
Iter: 20000 - Loss: 55.149
Iter: 30000 - Loss: 54.297
Iter: 40000 - Loss: 51.408
Iter: 50000 - Loss: 52.942
Iter: 60000 - Loss: 51.462
Iter: 70000 - Loss: 49.600
Iter: 80000 - Loss: 50.091
Iter: 90000 - Loss: 50.147
Iter: 100000 - Loss: 51.474
Iter: 110000 - Loss: 51.336
Iter: 120000 - Loss: 49.860
Iter: 130000 - Loss: 50.156
Iter: 140000 - Loss: 50.054
Iter: 150000 - Loss: 49.193
Iter: 160000 - Loss: 47.932
Iter: 170000 - Loss: 49.589
Iter: 180000 - Loss: 51.657
Iter: 190000 - Loss: 46.651
Iter: 200000 - Loss: 49.788
Iter: 210000 - Loss: 49.816
Iter: 220000 - Loss: 49.551
Iter: 230000 - Loss: 49.789
Iter: 240000 - Loss: 47.909
Iter: 250000 - Loss: 48.993
Iter: 260000 - Loss: 49.623
Iter: 270000 - Loss: 49.286
Iter: 280000 - Loss: 48.986
Iter: 290000 - Loss: 50.074
Iter: 300000 - Loss: 50.057
Iter: 310000 - Loss: 42.042
Iter: 320000 - Loss: 48.016
Iter: 330000 - Loss: 48.414
Iter: 340000 - Loss: 50.513
Iter: 350000 - Loss: 49.444
Iter: 36

KeyboardInterrupt: 