In [1]:
import os
import json
import pickle
from typing import List
from collections import Counter, OrderedDict
from itertools import chain

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

### Configuration

In [2]:
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Using {device} as device.")

Using cpu as device.


In [26]:
data_dir = "./data"
model_dir = "./models"

In [3]:
debug = True # set to false for full training run

In [4]:
if debug:
    CONTEXT_WINDOW = 2 # the number of words on either side of target word
    EMBEDDING_SIZE = 5
    MIN_FREQ = 25 # dropping words that appear less than 5 times
    BATCH_SIZE = 3
    N_EPOCHS = 1
else:
    CONTEXT_WINDOW = 4 # the number of words on either side of target word
    EMBEDDING_SIZE = 100
    MIN_FREQ = 25 # dropping words that appear less than 5 times
    BATCH_SIZE = 64
    N_EPOCHS = 2

### Data Preparation

In [32]:
os.makedirs(data_dir, exist_ok=True)
os.makedirs(model_dir, exist_ok=True)

In [None]:
#Extract Data
import requests

url = "https://www.gutenberg.org/cache/epub/7370/pg7370.txt"
response = requests.get(url)

with open("data.txt", "w", encoding="utf-8") as file:
    file.write(response.text)

Copying gs://adsp-nlp-open/data/word2vec_training_sentences.json...
==> NOTE: You are downloading one or more large file(s), which would            
run significantly faster if you enabled sliced object downloads. This
feature is enabled by default but requires that compiled crcmod be
installed (see "gsutil help crcmod").

/ [1 files][274.8 MiB/274.8 MiB]    1.2 MiB/s                                   
Operation completed over 1 objects/274.8 MiB.                                    


In [None]:
with open("data.txt", "r", encoding="utf-8") as file:
    text = file.read()


start = text.find("The Project Gutenberg eBook of Second Treatise of Government")
# Extract the text starting from the specified line
# No footer to remove in this case

sentences = text[start:]

# Split text into sentences based on periods and remove leading/trailing whitespace
sentences = [sentence.strip() for sentence in text.split('.') if sentence.strip()]

num_sentences = len(sentences)
print("Number of sentences:", num_sentences)

Number of sentences: 824,099


[['irish', 'league', 'cup'],
 ['sabre', 'engine', 'thrust', 'weight', 'ratio', 'up', 'to', 'atmospheric'],
 ['the',
  'recording',
  'was',
  'engineered',
  'by',
  'jim',
  'caruana',
  'and',
  'mixed',
  'by',
  'jason',
  'goldstein',
  'at',
  'sony',
  'music',
  'studios',
  'in',
  'new',
  'york',
  'city'],
 ['goldstein', 'was', 'hired', 'to', 'mix', 'b', 'day'],
 ['he', 'said', 'this', 'song', 'was', 'really', 'simple', 'to', 'mix']]

In [None]:
# Tokenize the sentences into words
# Sentence and word tokenize, clean, and lower case
from nltk.tokenize import word_tokenize, sent_tokenize
import re
sentences = sent_tokenize(text)
processed_sentences = []

for sentence in sentences:
    words = word_tokenize(sentence)
    words = [re.sub(r'\W+', '', word.lower()) for word in words if word.isalpha()]
    if words:
        processed_sentences.append(words)

# Build Vocabulary

In [9]:
class Vocab:
    def __init__(
        self,
        word_counts: OrderedDict, # vocabular is based on word counts
        min_freq: int = 1, # min times a word must appear in corpus (rare words might not be worth considering)
        max_size: int = None, # we can limit the amount of words as well 
        specials: List[str] = None, # any other special tokens we may want to add, like padding tokens
        unk_token: str = "<unk>" # reserved token for when we run into words not in the vocabulary
    ):
        self.word_counts = word_counts
        self.min_freq = min_freq
        self.max_size = max_size
        self.unk_token = unk_token
        self.specials = list(specials) if specials else []

        if self.unk_token not in self.specials:
            self.specials.insert(0, self.unk_token) # unknown token should always be included

        self.token2idx = {}
        self.idx2token = []

        self._prepare_vocab()


    def __len__(self):
        return len(self.idx2token)
    

    def __contains__(self, value):
        return value in self.idx2token


    def _prepare_vocab(self):
        """Processes input OrderedDict: Filters based on min_freq & adds special tokens."""
        vocab_list = self.specials.copy()  # Copy specials to avoid modifying original list

        # filter words based on min_freq and add to vocab
        filtered_words = [
            word
            for word, freq in self.word_counts.items()
            if freq >= self.min_freq and word not in self.specials
        ]

        # enforcing max vocab size constraint
        if self.max_size is not None:
            n_to_keep = self.max_size - len(self.specials) # special tokens take up spaces
            filtered_words = filtered_words[:n_to_keep]

        # creating final vocab list
        vocab_list.extend(word for word in filtered_words)

        # create look up tables
        self.idx2token = vocab_list
        self.token2idx = {word: idx for idx, word in enumerate(vocab_list)}


    def get_token(self, idx: int) -> str:
        """Returns the token corresponding to an index. Raises error if index is out of range."""
        if 0 <= idx < len(self.idx2token):
            return self.idx2token[idx]
        raise IndexError(f"Index {idx} is out of range for vocabulary size {len(self.idx2token)}")


    def get_index(self, token: str) -> int:
        """Returns the index corresponding to a token. Defaults to unk_token if missing."""
        return self.token2idx.get(token, self.token2idx[self.unk_token])  # return unk_token index if word is not in vocab


    def get_tokens(self, indices: List[int]) -> List[str]:
        """Converts a list of indices into a list of tokens."""
        return [self.get_token(idx) for idx in indices]


    def get_indices(self, tokens: List[str]) -> List[int]:
        """Converts a list of tokens into a list of indices."""
        return [self.get_index(token) for token in tokens]

In [10]:
def pad_sentences(sentences: List[List[str]], context_length: int, pad_token: str = "<pad>") -> List[List[str]]:
    """
    Pads each sentence to fit the context window length with the literal string "<pad>".
    
    Args:
        sentences: A list of sentences, where each sentence is a list of tokens.
        context_length: The number of tokens to either side of the target token.

    Returns:
        A list of padded sentences.
    """
    padded_sentences = []
    for sentence in sentences:
        padded_sentence = [pad_token] * context_length + sentence + [pad_token] * context_length
        padded_sentences.append(padded_sentence)
    
    return padded_sentences

In [None]:
sentences = pad_sentences(processed_sentences, CONTEXT_WINDOW)

In [12]:
sentences[:2]

[['<pad>', '<pad>', 'irish', 'league', 'cup', '<pad>', '<pad>'],
 ['<pad>',
  '<pad>',
  'sabre',
  'engine',
  'thrust',
  'weight',
  'ratio',
  'up',
  'to',
  'atmospheric',
  '<pad>',
  '<pad>']]

In [13]:
vocab = Vocab(
    word_counts=OrderedDict(Counter(chain.from_iterable(sentences))),
    min_freq=MIN_FREQ,
    specials=["<pad>"]
)

In [14]:
# creating a vocabulary
print(f"Size of Vocabulary: {len(vocab):,}")

Size of Vocabulary: 28,322


In [15]:
for idx in [0, 1, 5, 500, 10_000]:
    print(f"Index {idx} corresponds to `{vocab.get_token(idx)}`")

Index 0 corresponds to `<unk>`
Index 1 corresponds to `<pad>`
Index 5 corresponds to `sabre`
Index 500 corresponds to `insects`
Index 10000 corresponds to `dwarf`


Next we prepare the training data. For skip-gram model

In [None]:
import torch
from typing import List, Tuple

def generate_skipgram_pairs(sentences: List[List[str]], context_length: int, vocab) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Generate (center, context) pairs for Skip-gram model.

    Args:
        sentences: A list of sentences, where each sentence is a list of tokens.
        context_length: The number of tokens to either side of the center token.
        vocab: A vocab object that maps words to indices and vice versa.

    Returns:
        A tuple of two torch.Tensors:
            - centers: tensor of center word indices
            - contexts: tensor of corresponding context word indices
    """
    
    centers = []
    contexts = []

    for sentence in sentences:
        enc_sentence = vocab.get_indices(sentence)

        for center_idx in range(context_length, len(enc_sentence) - context_length):
            center = enc_sentence[center_idx]

            # Iterate over each context word around the center
            for i in range(-context_length, context_length + 1):
                if i == 0:
                    continue  # skip the center word itself
                context = enc_sentence[center_idx + i]
                centers.append(center)
                contexts.append(context)

    return torch.tensor(centers), torch.tensor(contexts)


In [None]:
centerss, contexts = generate_skipgram_pairs(sentences, CONTEXT_WINDOW, vocab)

print(f"Number of center-context pairs: {len(centerss):,}")
print(f"First 5 center indices: {centerss[:5]}")

Number of training examples: torch.Size([16409899])


In [18]:
# does does context look like?
print("contexts:", contexts, sep="\n")
print("contexts shape:", contexts.shape)

contexts:
tensor([[   1,    1,    3,    4],
        [   1,    2,    4,    1],
        [   2,    3,    1,    1],
        ...,
        [   1,    0, 1732,    1],
        [   0,  768,    1,    1],
        [   1,    1,    1,    1]])
contexts shape: torch.Size([16409899, 4])


In [19]:
# does do targets look like?
print("contexts:", targets)
print("contexts shape:", targets.shape)

contexts: tensor([   2,    3,    4,  ...,  768, 1732, 3001])
contexts shape: torch.Size([16409899])


In [20]:
# converting first context-target pair back to string
for idx in [6, 27, 1000]:
    print("context:", vocab.get_tokens(contexts[idx].tolist()))
    print("target:", vocab.get_tokens([targets[idx].item()]))
    print()

context: ['engine', 'thrust', 'ratio', 'up']
target: ['weight']

context: ['music', 'studios', 'new', 'york']
target: ['in']

context: ['to', 'favour', 'by', 'birds']
target: ['visits']



We also create a custome Dataset object so we can wrap it in a DataLoader object for batching, shuffling etc.

In [None]:
class NGramDataset(Dataset): # subclassing Dataset is required here
    
    def __init__(self, contexts, targets): # necessary method / function
        self.contexts = contexts
        self.targets = targets

    def __len__(self): # necessary method / function
        return len(self.contexts)

    def __getitem__(self, idx): # necessary method / function
        return self.contexts[idx], self.targets[idx]

### Create and Train Model

Next, we define the CBOW model.

In [None]:
model = NGramDataset(vocab_size=len(vocab)).to(device)
print(model)
print(f"Size of Vocabulary: {len(vocab):,}")