In [1]:
import time
import torch
from torch.utils.data import DataLoader, TensorDataset
import torch.nn as nn
from tqdm import tqdm
import pickle
import random
import numpy as np
from collections import Counter, defaultdict
import numpy as np
from torch import FloatTensor as FT

### Instructions
For this part, fill in the required code and make the notebook work. This wll be very similar to the Skip-Gram model, but a little more difficult. Look for the """ FILL IN """ string to guide you.

In [2]:
# Where do I want to run my job. You can do "cuda" on linux machines
# DEVICE = "mps" if torch.backends.mps.is_available() else  "cpu"
DEVICE = "cuda" if torch.cuda.is_available() else  "cpu"

# The batch size in Adam or SGD
BATCH_SIZE = 512

# Number of epochs
NUM_EPOCHS = 10

# Predict from 2 words the inner word for CBOW
# I.e. I'll have a window like ["a", "b", "c"] of continuous text (each is a word)
# We'll predict each of wc = ["a", "c"] from "b" = wc for Skip-Gram
# For CBOW, we'll use ["a", "c"] to predict "b" = wo
WINDOW = 1

# Negative samples.
K = 4

In [3]:
DEVICE

'cuda'

The text8 Wikipedia corpus. 100M characters.

In [4]:
from google.colab import drive
drive.mount('/content/drive')

!du -h text8

f = open('/content/drive/MyDrive/text8', 'r')
text = f.read()
# One big string of size 100M
print(len(text))

Mounted at /content/drive
du: cannot access 'text8': No such file or directory
100000000


In [5]:
punc = '!"#$%&()*+,-./:;<=>?@[\\]^_\'{|}~\t\n'

# Can do regular expressions here too
for c in punc:
    if c in text:
        text.replace(c, ' ')

In [7]:
def get_tokenizer(text):
  return text.lower().split()

# A very crude tokenizer you get for free: lower case and also split on spaces
# This will not work!
# Split text on space and strip each word
# You should get a list "words" which is text but each element is a word in order
TOKENIZER = get_tokenizer

' FILL IN '

In [8]:
words = TOKENIZER(text)
f = Counter(words)
# Fix the above

In [9]:
len(words)

17005207

In [10]:
# Do a very crude filter on the text which removes all very popular words
text = [word for word in words if f[word] > 5]

In [11]:
text[0:5]

['anarchism', 'originated', 'as', 'a', 'term']

In [13]:
class Vocab:
  def __init__(self, itos):
    self.itos = itos
    self.stoi = {token: idx for idx, token in enumerate(itos)}

  def get_stoi(self):
    return self.stoi

  def get_itos(self):
    return self.itos

  def __len__(self):
    return len(self.itos)

def build_vocab_from_iterator(iterator):
    counter = Counter()
    for tokens in iterator:
        counter.update(tokens)

    itos = []
    for token, _ in counter.most_common():
        if token not in itos:
            itos.append(token)

    return Vocab(itos)

VOCAB = build_vocab_from_iterator([text])
# Rebuild the vocabulary from text above

In [14]:
# Populate these maps manually using Counter or defaltdict
# This will not work

# word -> int hash map
stoi = VOCAB.get_stoi()
# int -> word hash map
itos = VOCAB.get_itos()

In [15]:
stoi['as']

11

In [16]:
# Total number of words; you should see about 63K as below
len(stoi)

63641

In [17]:
f = Counter(text)
# This is the probability that we pick a word in the corpus
z = {word: f[word] / len(text) for word in f}

In [18]:
threshold = 1e-5
# Probability that word is kept while subsampling
# This is explained here and sightly differet from the paper: http://mccormickml.com/2017/01/11/word2vec-tutorial-part-2-negative-sampling/
p_keep = {word: (np.sqrt(z[word] / 0.001) + 1)*(0.0001 / z[word]) for word in f}

In [19]:
# This is in the integer space
train_dataset = [word for word in text if random.random() < p_keep[word]]

# Rebuild the vocabulary
VOCAB = build_vocab_from_iterator([train_dataset])

In [20]:
len(train_dataset)

7845676

In [21]:
# Repopulate the stoi and itos maps again now that you dropped some words

# word -> int mapping
stoi = VOCAB.get_stoi()
# int -> word mapping
itos = VOCAB.get_itos()

In [22]:
# The vocabulary size after we do all the filters
len(VOCAB)

63641

In [23]:
# The probability we draw something for negative sampling
f = Counter(train_dataset)
p = torch.zeros(len(VOCAB))

# Downsample frequent words and upsample less frequent
s = sum([np.power(freq, 0.75) for word, freq in f.items()])

for word in f:
    p[stoi[word]] = np.power(f[word], 0.75) / s

In [24]:
# Map everything to integers
# This might not work be careful w the above ...
train_dataset = [stoi[word] for word in text]

In [45]:
# This just gets the (wc, wo) pairs that are positive - they are seen together!
def get_tokenized_dataset(dataset, verbose=False):
    x_list = []

    for i, token in enumerate(dataset):
        m = 1

        # Get the left and right tokens
        start = max(0, i - m)
        left_tokens = dataset[start: i]

        end = min(i + m, len(dataset) - 1)
        right_tokens = dataset[i + 1: end + 1]

        # Check these are the same length, and if so use them to add a row of data. This should be a list like
        # [a, c, b] where b is the center word
        if len(left_tokens) == len(right_tokens):
            w_context = left_tokens + right_tokens

            wc = [token]

            x_list.extend(
                [w_context + wc]
            )

    return x_list

In [46]:
train_x_list = get_tokenized_dataset(train_dataset, verbose=False)

In [48]:
pickle.dump(train_x_list, open('train_x_list.pkl', 'wb'))

In [49]:
train_x_list = pickle.load(open('train_x_list.pkl', 'rb'))

In [50]:
# These are (wc, wo) pairs. All are y = +1 by design
train_x_list[:10]

[[5233, 11, 3080],
 [3080, 6, 11],
 [11, 164, 6],
 [6, 1, 164],
 [164, 3133, 1],
 [1, 46, 3133],
 [3133, 60, 46],
 [46, 177, 60],
 [60, 123, 177],
 [177, 741, 123]]

In [51]:
# The number of things of BATCH_SIZE = 512
assert(len(train_x_list) // BATCH_SIZE == 32579)

'cuda'

### Set up the dataloader.

In [52]:
train_dl = DataLoader(
    TensorDataset(
        torch.tensor(train_x_list).to(DEVICE),
    ),
    batch_size=BATCH_SIZE,
    shuffle=True
)

In [54]:
for xb in train_dl:
    assert(xb[0].shape == (BATCH_SIZE, 3))
    break

### Words we'll use to asses the quality of the model ...

In [55]:
valid_ids = torch.tensor([
    stoi['money'],
    stoi['lion'],
    stoi['africa'],
    stoi['musician'],
    stoi['dance'],
])

### Get the model.

In [65]:
class CBOWNegativeSampling(nn.Module):
    def __init__(self, vocab_size, embed_dim):
        super(CBOWNegativeSampling, self).__init__()
        self.A = nn.Embedding(vocab_size, embed_dim) # Context vectors - center word
        self.B = nn.Embedding(vocab_size, embed_dim) # Output vectors - words around the center word
        self.init_weights()

    def init_weights(self):
        # Is this the best way? Not sure
        initrange = 0.5
        self.A.weight.data.uniform_(-initrange, initrange)
        self.B.weight.data.uniform_(-initrange, initrange)

    def forward(self, x):
        # N is the batch size
        # x is (N, 3)

        # Context words are 2m things, m = 1 so w_context is (N, 2) while wc is (N, 1)
        w_context, wc = x[:, : 2], x[:, 2: 3]

        # Each of these is (N, 2, D) since each context has 2 word
        # We want this to be (N, D) and this is what we get

        # (N, 2, D)
        a = self.A(w_context)

        # (N, D)
        a_avg = a.mean(dim = 1).squeeze(dim = 1)

        # Each of these is (N, D) since each target has 1 word
        b = self.B(wc).squeeze(dim = 1)

        # The product between each context and target vector. Look at the Skip-Gram code.
        # The logits is now (N, 1) since we sum across the final dimension.
        logits = (a_avg*b).sum(dim=1)

        return logits

In [66]:
@torch.no_grad()
def validate_embeddings(
    model,
    valid_ids,
    itos
):
    """ Validation logic """

    # We will use context embeddings to get the most similar words
    # Other strategies include: using target embeddings, mean embeddings after avaraging context/target
    embedding_weights = model.A.weight

    normalized_embeddings = embedding_weights.cpu() / np.sqrt(
        np.sum(embedding_weights.cpu().numpy()**2, axis=1, keepdims=True)
    )

    # Get the embeddings corresponding to valid_term_ids
    valid_embeddings = normalized_embeddings[valid_ids, :]

    # Compute the similarity between valid_term_ids (S) and all the embeddings (V)
    # We do S x d (d x V) => S x D and sort by negative similarity
    top_k = 10 # Top k items will be displayed
    similarity = np.dot(valid_embeddings.cpu().numpy(), normalized_embeddings.cpu().numpy().T)

    # Invert similarity matrix to negative
    # Ignore the first one because that would be the same word as the probe word
    similarity_top_k = np.argsort(-similarity, axis=1)[:, 1: top_k+1]

    # Print the output.
    for i, word_id in enumerate(valid_ids):
        # j >= 1 here since we don't want to include the word itself.
        similar_word_str = ', '.join([itos[j] for j in similarity_top_k[i, :] if j >= 1])
        # This might need a fix!
        print(f"{itos[word_id]}: {similar_word_str}")

    print('\n')

### Set up the model

In [67]:
LR = 10.0
NUM_EPOCHS = 10
EMBED_DIM = 300

In [68]:
model = CBOWNegativeSampling(len(VOCAB), EMBED_DIM).to(DEVICE)
optimizer = torch.optim.SGD(model.parameters(), lr=LR)

# The learning rate is lowered every epoch by 1/10
# Is this a good idea?
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1)

In [69]:
model

CBOWNegativeSampling(
  (A): Embedding(63641, 300)
  (B): Embedding(63641, 300)
)

In [70]:
validate_embeddings(model, valid_ids, itos)

money: salicylate, admittedly, faro, reactivation, cyclase, cvd, woo, piper, surgical, jfk
lion: branwell, plough, refurbished, virginis, politician, relaxed, erlk, melbourne, flutes, trento
africa: jpeg, cronus, rstendamm, heures, rachid, pythons, bolyai, cassettes, bergman, dichromate
musician: jaws, boar, rochester, circuit, representatives, taxing, commemorates, beverages, electrotechnical, maia
dance: substantiated, houghton, asks, darryl, magick, honours, fatimid, solzhenitsyn, falange, commotion




  normalized_embeddings = embedding_weights.cpu() / np.sqrt(


### Train the model

In [75]:
ratios = []

def train(dataloader, model, optimizer, epoch):
    model.train()
    total_acc, total_count, total_loss, total_batches = 0, 0, 0.0, 0.0
    log_interval = 500

    for idx, x_batch in tqdm(enumerate(dataloader)):

        x_batch = x_batch[0]

        batch_size = x_batch.shape[0]

        # Zero the gradient so they don't accumulate
        optimizer.zero_grad()

        logits = model(x_batch)

        # Get the positive samples loss. Notice we use weights here
        positive_loss = torch.nn.BCEWithLogitsLoss()(input=logits, target=torch.ones(batch_size).to(DEVICE).float())

        # For each batch, get some negative samples
        # We need a total of len(y_batch) * K samples across a batch
        # We then reshape this batch
        # These are effectively the output words
        negative_samples = torch.multinomial(p, batch_size * K, replacement=True).to(DEVICE)

        # Context words are 2m things, m = 1 so w_context is (N, 2) while wc is (N, 1)
        w_context, wc = x_batch[:, :2], x_batch[:, 2: 3]

        """
        if w_context looks like below (batch_size = 3)
        [
        (a, b),
        (c, d),
        (e, f)
        ] and K = 2 we'd like to get:

        [
        (a, b),
        (a, b),
        (c, d),
        (c, d),
        (e, f),
        (e, f)
        ]

        This will be batch_size * K rows.
        """

        # This should be (N * K, 2)
        w_context = torch.concat([
            w.repeat(K, 1) for w in torch.tensor(w_context).split(1)
        ]).to(DEVICE)

        wc = negative_samples[:, None]

        # Get the negative samples. This should be (N * K, 3)
        # Concatenate the w_context and wc along the column. Make sure everything is on CUDA / MPS or CPU
        x_batch_negative = torch.cat([w_context, wc], dim = 1).to(DEVICE)

        """
        Note the way we formulated the targets: they are all 0 since these are negative samples.
        We do the BCEWithLogitsLoss by hand basically here.
        Notice we sum across the negative samples, per positive word.

        This is literally the equation in the lecture notes.
        """

        # (N, K, D) -> (N, D) -> (N)
        # Look at the Skip-Gram notebook
        logits_negative = -model(x_batch_negative)
        negative_loss = torch.nn.BCEWithLogitsLoss()(input=logits_negative, target=torch.ones(batch_size * K).to(DEVICE).float())

        loss = (positive_loss + negative_loss).mean()

        # Get the gradients via back propagation
        loss.backward()

        # Clip the gradients? Generally a good idea
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)


        # Do an optimization step. Update the parameters A and B
        optimizer.step()
        # Get the new loss
        total_loss += loss.item()

        # Update the batch count
        total_batches += 1

        if idx % log_interval == 0:
            print(
                "| epoch {:3d} | {:5d}/{:5d} batches "
                "| loss {:8.3f} ".format(
                    epoch,
                    idx,
                    len(dataloader),
                    total_loss / total_batches
                )
            )
            validate_embeddings(model, valid_ids, itos)
            total_loss, total_batches = 0.0, 0.0

### Some results from the run look like below:

Somewhere inside of 2 iterations you should get sensible associattions.
Paste here a screenshot of the closest vectors.

In [None]:
for epoch in range(1, NUM_EPOCHS + 1):
    epoch_start_time = time.time()

    train(train_dl, model, optimizer, epoch)
    # We have a learning rate scheduler here

    # Basically, given the state of the optimizer, this lowers the learning rate in a smart way
    scheduler.step()

  w.repeat(K, 1) for w in torch.tensor(w_context).split(1)
  normalized_embeddings = embedding_weights.cpu() / np.sqrt(
1it [00:02,  2.14s/it]

| epoch   1 |     0/32580 batches | loss    1.568 
money: salicylate, admittedly, faro, reactivation, cyclase, cvd, woo, piper, surgical, jfk
lion: branwell, plough, refurbished, virginis, politician, relaxed, erlk, melbourne, flutes, trento
africa: jpeg, cronus, rstendamm, heures, rachid, pythons, bolyai, bergman, cassettes, dichromate
musician: jaws, boar, rochester, circuit, representatives, taxing, commemorates, beverages, electrotechnical, maia
dance: substantiated, houghton, asks, darryl, magick, honours, fatimid, solzhenitsyn, falange, commotion




500it [00:11, 51.00it/s]

| epoch   1 |   500/32580 batches | loss    1.360 
money: salicylate, reactivation, admittedly, cyclase, faro, piper, woo, jem, opportunistic, domesticated
lion: branwell, plough, refurbished, relaxed, virginis, politician, erlk, melbourne, trento, flutes
africa: jpeg, cronus, rachid, cassettes, rstendamm, heures, splicing, bolyai, dichromate, extensions
musician: jaws, boar, circuit, rochester, representatives, taxing, maia, electrotechnical, commemorates, beverages
dance: substantiated, houghton, asks, darryl, magick, honours, german, falange, fatimid, solzhenitsyn




1000it [00:21, 51.74it/s]

| epoch   1 |  1000/32580 batches | loss    1.254 
money: salicylate, cyclase, piper, reactivation, jem, admittedly, cvd, faro, opportunistic, domesticated
lion: branwell, plough, refurbished, relaxed, politician, virginis, erlk, melbourne, trento, flutes
africa: jpeg, cronus, cassettes, rachid, white, norville, entertaining, splicing, dichromate, rstendamm
musician: jaws, rochester, boar, circuit, representatives, taxing, electrotechnical, maia, beverages, commemorates
dance: substantiated, houghton, darryl, asks, magick, german, honours, commotion, thirsty, solzhenitsyn




1496it [00:31, 51.42it/s]

| epoch   1 |  1500/32580 batches | loss    1.208 
money: salicylate, cyclase, jem, piper, reactivation, opportunistic, poison, cvd, domesticated, jfk
lion: branwell, refurbished, plough, relaxed, erlk, virginis, politician, melbourne, trento, flutes
africa: jpeg, white, cronus, norville, rachid, entertaining, cassettes, splicing, extensions, bru
musician: jaws, rochester, circuit, boar, representatives, taxing, maia, electrotechnical, esat, subgenera
dance: houghton, substantiated, darryl, asks, magick, honours, german, shareholders, fatimid, recalls




1997it [00:41, 52.19it/s]

| epoch   1 |  2000/32580 batches | loss    1.177 
money: salicylate, piper, cyclase, jem, opportunistic, reactivation, domesticated, poison, cvd, jfk
lion: branwell, refurbished, plough, relaxed, erlk, virginis, melbourne, politician, trento, flutes
africa: jpeg, norville, white, cronus, rachid, entertaining, cassettes, bru, imparts, extensions
musician: rochester, jaws, boar, circuit, representatives, taxing, maia, electrotechnical, esat, subgenera
dance: substantiated, houghton, darryl, asks, magick, honours, german, shareholders, commotion, thirsty




2500it [00:51, 51.43it/s]

| epoch   1 |  2500/32580 batches | loss    1.150 
money: salicylate, piper, reactivation, cyclase, opportunistic, domesticated, jem, poison, cvd, jfk
lion: branwell, refurbished, plough, relaxed, erlk, melbourne, politician, virginis, trento, flutes
africa: jpeg, white, norville, entertaining, cronus, rachid, cassettes, imparts, bru, bolyai
musician: rochester, jaws, circuit, boar, representatives, taxing, maia, esat, electrotechnical, subgenera
dance: substantiated, houghton, asks, german, darryl, magick, recalls, gospel, honours, shareholders




2998it [01:01, 51.58it/s]

| epoch   1 |  3000/32580 batches | loss    1.126 
money: salicylate, piper, reactivation, cyclase, opportunistic, jem, domesticated, poison, cvd, subsidies
lion: branwell, refurbished, plough, relaxed, melbourne, erlk, virginis, politician, trento, flutes
africa: jpeg, norville, white, entertaining, cronus, imparts, bru, rachid, cassettes, extensions
musician: rochester, jaws, boar, representatives, circuit, taxing, maia, esat, electrotechnical, subgenera
dance: substantiated, houghton, german, asks, darryl, recalls, magick, gospel, commotion, fatimid




3500it [01:11, 51.98it/s]

| epoch   1 |  3500/32580 batches | loss    1.104 
money: salicylate, piper, poison, cyclase, opportunistic, reactivation, domesticated, jem, cvd, subsidies
lion: branwell, refurbished, plough, relaxed, melbourne, erlk, virginis, trento, politician, flutes
africa: white, jpeg, norville, entertaining, city, cronus, imparts, bru, extensions, rachid
musician: rochester, representatives, circuit, boar, jaws, taxing, maia, esat, subgenera, electrotechnical
dance: substantiated, houghton, asks, recalls, darryl, german, gospel, magick, thirsty, commotion




3998it [01:21, 52.69it/s]

| epoch   1 |  4000/32580 batches | loss    1.084 
money: salicylate, piper, poison, jem, opportunistic, domesticated, reactivation, cyclase, subsidies, cvd
lion: branwell, refurbished, relaxed, plough, erlk, melbourne, virginis, politician, flutes, trento
africa: white, norville, jpeg, entertaining, city, cronus, imparts, bru, splicing, allocating
musician: rochester, representatives, boar, circuit, jaws, maia, esat, taxing, subgenera, electrotechnical
dance: substantiated, houghton, asks, recalls, gospel, darryl, hell, german, thirsty, shareholders




4501it [01:31, 34.44it/s]

| epoch   1 |  4500/32580 batches | loss    1.067 
money: salicylate, piper, opportunistic, poison, domesticated, jem, subsidies, cyclase, reactivation, cvd
lion: branwell, refurbished, relaxed, plough, erlk, melbourne, politician, virginis, flutes, itosu
africa: white, city, norville, jpeg, entertaining, cronus, bru, imparts, flanged, splicing
musician: rochester, representatives, jaws, esat, boar, maia, circuit, taxing, subgenera, trophyless
dance: substantiated, houghton, gospel, asks, recalls, german, hell, darryl, thirsty, shareholders




4998it [01:41, 51.10it/s]

| epoch   1 |  5000/32580 batches | loss    1.050 
money: salicylate, piper, opportunistic, domesticated, poison, subsidies, jem, cyclase, cvd, discs
lion: branwell, refurbished, relaxed, plough, melbourne, erlk, politician, virginis, itosu, armaments
africa: white, city, entertaining, norville, jpeg, cronus, bru, imparts, example, splicing
musician: rochester, representatives, circuit, jaws, esat, maia, boar, taxing, subgenera, electrotechnical
dance: substantiated, gospel, houghton, asks, recalls, german, hell, darryl, thirsty, shareholders




5004it [01:41, 33.32it/s]