In [2]:
### author: Chen Zheng
### Date: 10/13/2022

In [3]:
import numpy as np
import torch
from torch import nn
import torch.utils.data as Data
import random

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
##########################################################################################################################
# download the Penn Treebank dataset: https://deepai.org/dataset/penn-treebank
##########################################################################################################################
# Read the training split; `with` guarantees the file is closed even if reading fails.
with open('ptb.train.txt', 'r') as file:
    lines = file.readlines()
# One sentence per line, tokenized on whitespace.
dataset = [sentence.split() for sentence in lines]
print('ptb.train.txt has {} sentences.'.format(len(dataset)))


ptb.train.txt has 42068 sentences.


In [5]:
import collections
# Occurrence count of every token across the whole corpus (key: word, value: number of appearances).
word_and_frequency_dict = collections.Counter(word for sentence in dataset for word in sentence)
print('Before filtering, the vocabulary size of The dataset is: ', len(word_and_frequency_dict))
# Keep only words that appear at least 5 times; the result is a plain dict in the same key order.
word_and_frequency_dict = {word: freq for word, freq in word_and_frequency_dict.items() if freq >= 5}
print('After filtering, the vocabulary size of The dataset is: ', len(word_and_frequency_dict))

Before filtering, the vocabulary size of The dataset is:  9999
After filtering, the vocabulary size of The dataset is:  9858


In [6]:
##########################################################################################################################
# create word index table
##########################################################################################################################
# id -> word: ids follow the key order of word_and_frequency_dict.
index_to_word = list(word_and_frequency_dict)
# word -> id: the inverse of index_to_word.
word_to_index = {word: index for index, word in enumerate(index_to_word)}


In [7]:
##########################################################################################################################
# dataset words: string to index
##########################################################################################################################
# Replace each token by its integer id; tokens removed from the vocabulary are dropped.
# NOTE: this rebinds `dataset`, so re-running this cell alone would filter out every (now int) token.
dataset = [
    [word_to_index[token] for token in filter(word_to_index.__contains__, sentence)]
    for sentence in dataset
]

In [8]:
##########################################################################################################################
# extract center_word, and surrounding_words.
# e.g.: Hello world I love pytorch.  window size = 2
# center_word = 'hello', surrounding_words = ['world', 'I']
# center_word = 'world', surrounding_words = ['hello', 'I']
# center_word = 'I', surrounding_words = ['hello', 'world', 'love', 'pytorch']
# center_word = 'love', surrounding_words = ['world', 'I', 'pytorch']
# center_word = 'pytorch', surrounding_words = ['I', 'love']
##########################################################################################################################
window_size = 3 ### tips: you can try window_size = 4, 5, ...
center_word_list, surrounding_words_list = [], []
for sentence in dataset:
    # Sentences with fewer than two words have no context and yield no examples.
    if len(sentence) < 2:
        continue
    center_word_list.extend(sentence)
    for pos in range(len(sentence)):
        # Up to window_size words on each side of the center word; slicing clips at the sentence edges.
        left = sentence[max(0, pos - window_size):pos]
        right = sentence[pos + 1:pos + 1 + window_size]
        surrounding_words_list.append(left + right)
print('We have {} center words.'.format(len(center_word_list)))
print('We have {} surrounding word list.'.format(len(surrounding_words_list)))
example_id = 200
for offset in range(10):
    idx = example_id + offset
    print('example ', idx, ': center word: ', center_word_list[idx], ' surrounding words: ', surrounding_words_list[idx])

We have 886963 center words.
We have 886963 surrounding word list.
example  200 : center word:  7  surrounding words:  [121, 1, 122, 57, 123, 124]
example  201 : center word:  57  surrounding words:  [1, 122, 7, 123, 124, 7]
example  202 : center word:  123  surrounding words:  [122, 7, 57, 124, 7, 51]
example  203 : center word:  124  surrounding words:  [7, 57, 123, 7, 51, 88]
example  204 : center word:  7  surrounding words:  [57, 123, 124, 51, 88, 125]
example  205 : center word:  51  surrounding words:  [123, 124, 7, 88, 125, 17]
example  206 : center word:  88  surrounding words:  [124, 7, 51, 125, 17, 113]
example  207 : center word:  125  surrounding words:  [7, 51, 88, 17, 113, 126]
example  208 : center word:  17  surrounding words:  [51, 88, 125, 113, 126, 127]
example  209 : center word:  113  surrounding words:  [88, 125, 17, 126, 127, 128]


In [9]:
#########################################################################################################################
# negative sampling. 
# for each list of surrounding words, we randomly draw some 'noise' words (negative samples).
#########################################################################################################################
K = 5 ## number of negative (noise) words drawn per surrounding (context) word
# Noise distribution: word frequency raised to the 0.75 power, aligned with index_to_word ids.
sampling_weights = [word_and_frequency_dict[word]**0.75 for word in index_to_word]
negative_samplings_list = []
negative_candidates = []
count = 0
population = list(range(len(sampling_weights)))
# NOTE(review): `random` is never seeded, so the sampled negatives differ between runs.
for cur_sur_words in surrounding_words_list:
    negatives = []
    # Build the exclusion set once per context window rather than once per draw.
    context_set = set(cur_sur_words)
    while len(negatives) < len(cur_sur_words) * K:
        # Refill the candidate buffer in bulk (1e5 draws at a time) instead of calling random.choices per sample.
        if count == len(negative_candidates):
            negative_candidates = random.choices(population, sampling_weights, k=int(1e5))
            count = 0
        neg = negative_candidates[count]
        count = count + 1
        # Reject candidates that are real context words of this center word.
        if neg not in context_set:
            negatives.append(neg)
    negative_samplings_list.append(negatives)

In [10]:
##########################################################################################################################
# dataset preprocessing
##########################################################################################################################
class W_2_V_Data(torch.utils.data.Dataset):
    """Map-style dataset of (center_word, surrounding_words, negative_words) triples.

    The three constructor lists are parallel: position ``i`` in each describes the
    same training example. The lists are stored by reference, not copied.
    """

    def __init__(self, center_word_list, surrounding_words_list, negative_samplings_list):
        self.center_word_list, self.surrounding_words_list, self.negative_samplings_list = (
            center_word_list, surrounding_words_list, negative_samplings_list)

    def __getitem__(self, index):
        columns = (self.center_word_list, self.surrounding_words_list, self.negative_samplings_list)
        return tuple(column[index] for column in columns)

    def __len__(self):
        # The example count is taken from the center-word list.
        return len(self.center_word_list)

def set_up_batch_data(data):
    """Collate (center, context, negatives) examples into right-padded tensors.

    Every row of candidates is ``context + negatives`` padded with id 0 up to the
    longest row in the batch.

    Returns:
        centers:  (batch, 1) center-word ids
        candidates: (batch, width) context ids, then negative ids, then 0-padding
        masks:    (batch, width) 1 for real slots, 0 for padding
        labels:   (batch, width) 1 for context (positive) slots, 0 elsewhere
    """
    width = max(len(ctx) + len(neg) for _, ctx, neg in data)
    center_ids, candidate_rows, mask_rows, label_rows = [], [], [], []
    for center, context, negative in data:
        used = len(context) + len(negative)
        pad = width - used
        center_ids.append(center)
        candidate_rows.append(context + negative + [0] * pad)
        mask_rows.append([1] * used + [0] * pad)
        label_rows.append([1] * len(context) + [0] * (width - len(context)))
    return (torch.tensor(center_ids).view(-1, 1), torch.tensor(candidate_rows),
            torch.tensor(mask_rows), torch.tensor(label_rows))

batch_size = 128
w2v_data = W_2_V_Data(center_word_list, surrounding_words_list, negative_samplings_list)
# Shuffle every epoch; set_up_batch_data pads each batch to its longest context+negative row.
data_loader = Data.DataLoader(dataset=w2v_data, batch_size=batch_size, shuffle=True,
                              collate_fn=set_up_batch_data)

In [11]:
##########################################################################################################################
# model
##########################################################################################################################
class W2V_skipgram(nn.Module):
    """Skip-gram scorer with separate center-word and context-word embedding tables.

    Despite their names, forward's ``emb1``/``emb2`` are integer id tensors: ``emb1``
    indexes the center table and ``emb2`` the context/negative table. The output is
    the dot product of each center embedding with every candidate embedding (logits).
    """

    def __init__(self, vocab_size, embedding_size):
        super().__init__()
        # Table used for words in the center position.
        self.emb_1 = nn.Embedding(vocab_size, embedding_size)
        # Table used for words in the context (positive or negative) position.
        self.emb_2 = nn.Embedding(vocab_size, embedding_size)

    def forward(self, emb1, emb2):
        center_vecs = self.emb_1(emb1)     # (batch, 1, E) when emb1 is (batch, 1)
        candidate_vecs = self.emb_2(emb2)  # (batch, L, E) when emb2 is (batch, L)
        # (batch, 1, E) @ (batch, E, L) -> (batch, 1, L)
        return torch.bmm(center_vecs, candidate_vecs.transpose(1, 2))


##########################################################################################################################
# loss function: extension of the binary_cross_entropy_with_logits
##########################################################################################################################
class SigmoidBinaryCrossEntropyLoss(nn.Module):
    """Element-wise sigmoid binary cross-entropy on logits, with an optional mask, averaged per row.

    forward(inputs, targets, mask=None):
        inputs:  logits; in this notebook shaped (batch, L)
        targets: 0/1 labels with the same shape as inputs
        mask:    optional 0/1 weights with the same shape; None means every slot counts.
    Returns a (batch,) tensor: the mean over dim 1 of the (weighted) element losses.

    The mean divides by L including masked-out slots. To average over real slots
    only, the caller must rescale by L / mask.sum(dim=1).
    """

    def __init__(self):
        super(SigmoidBinaryCrossEntropyLoss, self).__init__()

    def forward(self, inputs, targets, mask=None):
        inputs, targets = inputs.float(), targets.float()
        # mask is optional; calling .float() on None would raise, so pass weight=None instead.
        weight = None if mask is None else mask.float()
        res = nn.functional.binary_cross_entropy_with_logits(
            inputs, targets, reduction="none", weight=weight)
        return res.mean(dim=1)

In [12]:
##########################################################################################################################
# train the word2vec skipgram model
##########################################################################################################################
num_epochs = 5
lr = 0.01
loss_fun = SigmoidBinaryCrossEntropyLoss()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

vocab_size = len(index_to_word)
embedding_size = 100
# nn.Module.to moves the parameters and returns the same module.
net = W2V_skipgram(vocab_size, embedding_size).to(device)

optimizer = torch.optim.Adam(net.parameters(), lr=lr)
for epoch in range(num_epochs):
    loss_sum, n = 0.0, 0
    for batch in data_loader:
        center, context_negative, mask, label = [tensor.to(device) for tensor in batch]

        # (batch, 1, L) logits reshaped to (batch, L) to align with the labels.
        pred = net(center, context_negative)
        per_example = loss_fun(pred.view(label.shape), label, mask)
        # loss_fun averages over all L slots, padding included; scaling by
        # L / (number of real slots) turns it into a mean over real slots only.
        loss = (per_example * mask.shape[1] / mask.float().sum(dim=1)).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        n += 1
    print('epoch: {}, loss: {}.'.format(epoch + 1, loss_sum / n))

epoch: 1, loss: 0.6673071845957383.
epoch: 2, loss: 0.41197761281314177.
epoch: 3, loss: 0.39991736626212215.
epoch: 4, loss: 0.39462959344510906.
epoch: 5, loss: 0.3913673517550913.


In [17]:
def get_similar_tokens(query_token, k, embed):
    """Print the k vocabulary words whose embeddings have the highest cosine similarity to query_token.

    Looks up ids and words through the module-level word_to_index / index_to_word tables.
    Raises KeyError if query_token is not in the vocabulary.
    """
    weights = embed.weight.data
    query_vec = weights[word_to_index[query_token]]
    norms = (torch.sum(weights * weights, dim=1) * torch.sum(query_vec * query_vec) + 1e-9).sqrt()
    cos = torch.matmul(weights, query_vec) / norms
    top_ids = torch.topk(cos, k=k + 1)[1].cpu().numpy()
    # The first hit is assumed to be the query word itself (cosine ~1), so it is skipped.
    for idx in top_ids[1:]:
        print('cosine sim=%.3f: %s' % (cos[idx], index_to_word[idx]))

get_similar_tokens('dog', 10, net.emb_1)

cosine sim=0.501: performer
cosine sim=0.473: spacecraft
cosine sim=0.460: cells
cosine sim=0.435: taxi
cosine sim=0.417: sandinistas
cosine sim=0.409: eye
cosine sim=0.406: drug
cosine sim=0.401: movement
cosine sim=0.398: remedy
cosine sim=0.397: brain
