# Assignment 1.4: Negative sampling (15 points)

You may have noticed that word2vec is really slow to train, especially with big (> 50 000) vocabularies. Negative sampling is the solution.

The task is to implement word2vec with negative sampling.

This is what was discussed in the Stanford lecture. The main idea is captured by the formula:

$$ L = \log\sigma(u^T_o \cdot u_c) + \sum^k_{i=1} \mathbb{E}_{j \sim P(w)}[\log\sigma(-u^T_j \cdot u_c)]$$

Where $\sigma$ is the sigmoid function, $u_c$ is the central word vector, $u_o$ is the context ("outside") word vector, i.e. the vector of a word that appears in the window around the central word, and $u_j$ is the vector of the word with index $j$.

The first term measures the similarity between the positive examples (words from the same window).

The second term is responsible for negative samples. $k$ is a hyperparameter - the number of negatives to sample.
$\mathbb{E}_{j \sim P(w)}$
means that $j$ is distributed according to the unigram distribution.

Thus, it is only required to calculate the similarity for the positive pair and a few sampled negatives, not across the whole vocabulary.

Useful links:
1. [Efficient Estimation of Word Representations in Vector Space](https://arxiv.org/pdf/1301.3781.pdf)
1. [Distributed Representations of Words and Phrases and their Compositionality](http://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf)

In [140]:
import gc
import string
import re
from collections import Counter
import numpy as np
gc.collect()
import nltk
from nltk.corpus import stopwords
STOP_WORDS = set(stopwords.words('english'))
len(STOP_WORDS)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
torch.manual_seed(1)
from torch.optim.lr_scheduler import ExponentialLR, ReduceLROnPlateau
import numpy

# Select the compute device once; the rest of the notebook reads `device`.
USE_GPU = True
dtype = torch.float32 # we will be using float throughout this tutorial
if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
    # Only query the GPU name when CUDA is actually usable; calling
    # torch.cuda.get_device_name() on a CPU-only machine raises.
    print (torch.cuda.get_device_name(0))
else:
    device = torch.device('cpu')
    print ('CUDA is unavailable or disabled; using CPU')

GeForce GTX 1050 Ti


_______________________________

In [81]:
class Batcher:
    """Turn a raw text corpus into index sequences and (context, center) batches.

    Call ``read_data`` first to populate the vocabulary and the indexed
    corpus, then iterate over ``generator()`` to obtain training batches.

    Attributes set by ``read_data``:
        word2index: dict mapping each token (including 'UNK' and 'PAD'
            when present) to its integer index.
        index2word: inverse mapping of ``word2index``.
        freq: Counter of occurrences of every index in the padded corpus.
        voc: set of all indices present in the corpus.
        voc_size: number of distinct indices.
        corpus: list of indices for the padded corpus.
        corpus_size: length of ``corpus``.
    """

    def __init__(self, max_len, window_size, corpus_path, min_freq, max_freq, max_voc_size, batch_size):
        """Store the configuration; no file is read until ``read_data``.

        Args:
            max_len: maximum number of characters taken from the corpus file.
            window_size: number of context words on each side of the center.
            corpus_path: path of the text file read when ``read_data(None)``.
            min_freq: drop words rarer than this (``None`` disables the filter).
            max_freq: drop words more frequent than this (``None`` disables).
            max_voc_size: keep at most this many most-common words.
            batch_size: number of center words per yielded batch.
        """
        self.corpus_path = corpus_path
        self.window_size = window_size
        self.min_freq = min_freq
        self.max_freq = max_freq
        self.max_voc_size = max_voc_size
        self.batch_size = batch_size
        self.max_len = max_len
        self.words = None
        self.word2index = None
        self.index2word = None
        self.freq = None
        self.voc = None
        self.voc_size = None
        self.corpus = None
        self.corpus_size = None

    def read_data(self, S):
        """Tokenize the text, build the vocabulary and index the corpus.

        Args:
            S: the text to process, or ``None`` to read ``corpus_path``.
                NOTE(review): only text read from the file is lower-cased
                and truncated to ``max_len``; a string passed in directly
                is used as-is (original behavior, kept) - confirm intent.

        Progress statistics are printed along the way. Punctuation is
        replaced by spaces, stop words are removed, filtered-out words
        become 'UNK', and the corpus is padded with 'PAD' tokens.
        """
        if S is None:
            with open(self.corpus_path, 'r') as f:
                # Read only as much as will be kept instead of loading
                # the whole file and slicing afterwards.
                S = f.read(self.max_len)
            S = S.lower()[: self.max_len]
        print('Len of S = ', len(S))
        regex = re.compile('[%s]' % re.escape(string.punctuation))
        S = regex.sub(' ', S)
        words_raw = S.split()
        print(len(words_raw))
        words = [word for word in words_raw if word not in STOP_WORDS]

        print('Size of words = ', len(words))
        counter = Counter(words)
        print('Size of counter = ', len(counter))
        if self.min_freq is not None:
            counter = {w: c for w, c in counter.items() if c >= self.min_freq}
        print('Size of counter after min_freq = ', len(counter))
        if self.max_freq is not None:
            counter = {w: c for w, c in counter.items() if c <= self.max_freq}
        print('Size of counter after max_freq = ', len(counter))
        counter = Counter(counter)

        freq = dict(counter.most_common(self.max_voc_size))
        voc = set(freq)

        # Every word that did not survive the frequency / size filters.
        unk = set(words).difference(voc)
        print('Size of freq dict = ', len(voc))
        print('Number of vocabulary words = ', len(voc))
        print('Number of unknown words = ', len(unk))

        words = ['UNK' if word in unk else word for word in words]

        # Pad the front with one window, and the back with one window plus
        # whatever is needed so the number of center positions is a
        # multiple of batch_size (generator() only yields full batches).
        remainder = len(words) % self.batch_size
        if remainder == 0:
            padding = self.window_size
        else:
            padding = self.batch_size - remainder + self.window_size

        words = ['PAD'] * self.window_size + words + ['PAD'] * padding
        # Sort so that the word -> index assignment is the same on every
        # run; iterating a bare set depends on string hash randomization.
        unique_words = sorted(set(words))
        print('Size of corpus = ', len(words))
        print('Size of vocabulary = ', len(unique_words))
        self.word2index = {word: idx for idx, word in enumerate(unique_words)}
        self.index2word = dict(enumerate(unique_words))
        words = [self.word2index[word] for word in words]
        self.freq = Counter(words)
        self.voc = set(self.freq)
        self.voc_size = len(self.voc)
        self.corpus = words
        self.corpus_size = len(words)

    def generator(self):
        """Yield ``(x, y)`` NumPy batches of context windows and center words.

        For each center position, ``x`` holds the ``window_size`` indices
        on each side and ``y`` holds the center index wrapped in a list,
        so a full batch has shapes ``(batch_size, 2 * window_size)`` and
        ``(batch_size, 1)``. A trailing incomplete batch is not yielded.
        """
        window = self.window_size
        x_batch, y_batch = [], []
        for i in range(window, self.corpus_size - window):
            x_batch.append(self.corpus[i - window: i] + self.corpus[i + 1: i + window + 1])
            y_batch.append([self.corpus[i]])
            if len(x_batch) == self.batch_size:
                yield np.array(x_batch), np.array(y_batch)
                x_batch, y_batch = [], []

In [149]:
# Corpus / batching configuration shared by the cells below.
BATCH_SIZE = 8
MAX_LEN = 10 ** 7  # number of characters of the corpus file to use

batcher = Batcher(
    max_len=MAX_LEN,
    window_size=2,
    corpus_path='text8',
    min_freq=5,
    max_freq=None,
    max_voc_size=10 ** 7,
    batch_size=BATCH_SIZE,
)
# S=None makes the batcher load the text from corpus_path.
batcher.read_data(S=None)

Len of S =  10000000
1706282
Size of words =  1090922
Size of counter =  70835
Size of counter after min_freq =  19359
Size of counter after max_freq =  19359
Size of freq dict =  19359
Number of vocabulary words =  19359
Number of unknown words =  51476
Size of corpus =  1090932
Size of vocabulary =  19361


In [150]:
# Sanity check: print the array shapes of the first batch only.
first_batch = next(batcher.generator(), None)
if first_batch is not None:
    x, y = first_batch
    print(x.shape, y.shape)

(8, 4) (8, 1)


In [151]:
# Sanity check: print every (center word, context word) pair of the first batch.
first_batch = next(batcher.generator(), None)
if first_batch is not None:
    x, y = first_batch
    for context_row, target_row in zip(x, y):
        target_word = target_row[0]
        for context_word in context_row:
            print(batcher.index2word[target_word], batcher.index2word[context_word])

anarchism PAD
anarchism PAD
anarchism originated
anarchism term
originated PAD
originated anarchism
originated term
originated abuse
term anarchism
term originated
term abuse
term first
abuse originated
abuse term
abuse first
abuse used
first term
first abuse
first used
first early
used abuse
used first
used early
used working
early first
early used
early working
early class
working used
working early
working class
working radicals


In [159]:
class CBOW(nn.Module):
    """Word-pair embedding model trained with the negative-sampling loss.

    For a (target, context) pair the loss is
    ``-(log sigmoid(u_ctx . v_tgt) + sum_k log sigmoid(-u_neg_k . v_tgt))``
    where ``v`` comes from ``embedding1``, ``u`` from ``embedding2`` and
    the negative words are drawn from a noise distribution over the
    vocabulary.
    """

    def __init__(self, voc_size, embedding_dim, window_size, batch_size,
                 noise_dist=None, num_negatives=5):
        """Create the two embedding tables.

        Args:
            voc_size: number of rows in each embedding table.
            embedding_dim: size of each embedding vector.
            window_size: unused; kept for interface compatibility.
            batch_size: unused; kept for interface compatibility.
            noise_dist: optional non-negative weight per vocabulary index
                (need not sum to one) used to draw negative words. When
                omitted, unigram counts are derived from the module-level
                ``batcher.corpus`` on first use, which matches drawing a
                random corpus position.
            num_negatives: number of negative words sampled per pair.
        """
        super(CBOW, self).__init__()
        self.embedding1 = nn.Embedding(voc_size, embedding_dim)
        self.embedding2 = nn.Embedding(voc_size, embedding_dim)
        self.num_negatives = num_negatives
        self._noise_dist = None
        if noise_dist is not None:
            self._noise_dist = torch.as_tensor(noise_dist, dtype=torch.float32)

    def _noise_weights(self, dev):
        """Return the cached sampling weights on device ``dev``."""
        if self._noise_dist is None:
            # Build the unigram table once instead of converting the whole
            # corpus list on every single draw.
            counts = torch.bincount(
                torch.as_tensor(batcher.corpus, dtype=torch.long),
                minlength=self.embedding2.num_embeddings,
            )
            self._noise_dist = counts.to(torch.float32)
        if self._noise_dist.device != dev:
            self._noise_dist = self._noise_dist.to(dev)
        return self._noise_dist

    def forward(self, target_word, context_word):
        """Return the summed negative-sampling loss as a 0-dim tensor.

        Args:
            target_word: index (or same-shaped collection of indices) of
                the center word(s).
            context_word: index/indices of the positive context word(s),
                with the same shape as ``target_word``.
        """
        # Follow wherever the module's parameters live rather than
        # assuming CUDA, so the model also runs on CPU.
        dev = self.embedding1.weight.device
        target_idx = torch.as_tensor(target_word, dtype=torch.long, device=dev)
        context_idx = torch.as_tensor(context_word, dtype=torch.long, device=dev)

        target_emb = self.embedding1(target_idx)
        context_emb = self.embedding2(context_idx)

        # One dot product per pair (last axis is the embedding dimension).
        pos_score = torch.sum(target_emb * context_emb, dim=-1)
        log_lik = F.logsigmoid(pos_score)

        n_pairs = pos_score.numel()
        if self.num_negatives > 0 and n_pairs > 0:
            # Draw all negatives in one call; sampling through torch means
            # torch.manual_seed makes the run reproducible.
            negative_idx = torch.multinomial(
                self._noise_weights(dev),
                n_pairs * self.num_negatives,
                replacement=True,
            ).view(tuple(pos_score.shape) + (self.num_negatives,))
            negative_emb = self.embedding2(negative_idx)
            neg_score = torch.sum(negative_emb * target_emb.unsqueeze(-2), dim=-1)
            log_lik = log_lik + F.logsigmoid(-neg_score).sum(dim=-1)

        return -log_lik.sum()

In [160]:
### ReduceLROnPlateau

# Per-batch summed losses, stored as plain floats so that no autograd
# graph is kept alive between iterations.
losses = []
#loss_function = nn.NLLLoss()
model = CBOW(voc_size=batcher.voc_size, embedding_dim=256, window_size=batcher.window_size, batch_size=batcher.batch_size)
# Use the device selected at the top of the notebook instead of forcing CUDA.
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=0.01)
lr_scheduler = ReduceLROnPlateau(optimizer = optimizer, \
                                 mode = 'min', \
                                 factor = 0.5, \
                                 threshold = 0.001 \
                                )

# Number of (target, context) pairs in one full batch; used to report the
# mean loss per pair.
pairs_per_batch = BATCH_SIZE*batcher.window_size*2

model.train()
for epoch in [0, 1, 2]:
    print('========== Epoch {} =========='.format(epoch))
    total_loss = 0.0
    num_batches = 0
    for i, (context, target) in enumerate(batcher.generator(), start=1):
        batch_loss = 0.0
        for k in range(context.shape[0]):
            target_word = target[k][0]
            for j in range(context.shape[1]):
                context_word = context[k][j]

                # One optimizer step per (target, context) pair.
                loss = model(target_word, context_word)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # .item() detaches the value; adding the tensor itself
                # would retain every pair's graph in memory.
                batch_loss += loss.item()
        print('Batch {} loss : {}'.format(i, batch_loss/pairs_per_batch))
        losses.append(batch_loss)
        total_loss += batch_loss
        num_batches = i
    # ReduceLROnPlateau only acts when it is stepped with the monitored
    # metric; feed it the mean batch loss of the epoch.
    if num_batches:
        lr_scheduler.step(total_loss/num_batches)
    # NOTE(review): this unconditional break stops training after the
    # first epoch (original behavior, kept); remove it to run all epochs.
    break

Batch 1 loss : 8.554247856140137
Batch 2 loss : 5.681187629699707
Batch 3 loss : 7.376816749572754
Batch 4 loss : 6.6656365394592285
Batch 5 loss : 7.099233627319336
Batch 6 loss : 4.026013374328613
Batch 7 loss : 9.054348945617676
Batch 8 loss : 7.0196990966796875
Batch 9 loss : 3.605301856994629
Batch 10 loss : 4.993880748748779
Batch 11 loss : 7.086320877075195
Batch 12 loss : 8.006959915161133
Batch 13 loss : 7.145155429840088
Batch 14 loss : 3.966289520263672
Batch 15 loss : 5.272380352020264
Batch 16 loss : 4.349819660186768
Batch 17 loss : 6.3697991371154785
Batch 18 loss : 6.144553184509277
Batch 19 loss : 5.561038970947266
Batch 20 loss : 6.398571014404297
Batch 21 loss : 3.7784817218780518
Batch 22 loss : 8.237434387207031
Batch 23 loss : 4.421940326690674
Batch 24 loss : 4.251057147979736
Batch 25 loss : 7.601797103881836
Batch 26 loss : 5.603153228759766
Batch 27 loss : 8.381185531616211
Batch 28 loss : 4.720135688781738
Batch 29 loss : 8.252762794494629
Batch 30 loss : 6.1

KeyboardInterrupt: 