# DataLoader

In [384]:
import torch
from torch.utils.data import DataLoader, Dataset
import numpy as np


class TXTDataset(Dataset):
    """Sliding-window word2vec dataset built from a whitespace-tokenised text file.

    Each item is centred on one corpus position and carries the ids of the
    ``2 * window`` surrounding words. Close to either end of the corpus the
    window is shifted inwards, so every item has the same context length and
    the default DataLoader collation can stack them.

    Args:
        path: text file to read; defaults to the location used by the notebook.
        max_words: keep only the first ``max_words`` tokens of the corpus.
        window: number of context words taken on each side of the centre word.
        mode: ``'cbow'`` -> ``(context, target)``;
              ``'skipgram'`` -> ``(center, context, negative_samples)``.

    Raises:
        ValueError: if ``mode`` is not one of the two supported values.

    NOTE(review): ``__len__`` returns the vocabulary size (callers in this
    notebook use ``len(dataset)`` to size the model), so a DataLoader only
    draws centre positions from the first ``vocab_size`` words of the corpus.
    """

    def __init__(self,
                 path='/Users/georgychernousov/studying-ml/word2vec/reviews.txt',
                 max_words=1000000,
                 window=2,
                 mode='cbow'):
        if mode not in ('cbow', 'skipgram'):
            raise ValueError(f"mode must be 'cbow' or 'skipgram', got {mode!r}")

        with open(path, 'r') as f:
            reviews = f.readlines()
        self.words = ' '.join(reviews).split()[:max_words]

        # Sort the vocabulary: iterating a raw set gives an order that depends
        # on string hashing, so word ids would differ between kernel sessions.
        vocab = sorted(set(self.words))
        vocab_size = len(vocab)
        self.vocab_size = vocab_size
        self.n_namples = vocab_size  # legacy (misspelled) name, kept for existing callers
        self.word_to_ix = {word: ix for ix, word in enumerate(vocab)}
        self.ix_to_word = {ix: word for ix, word in enumerate(vocab)}
        self.window = window
        self.mode = mode
        self.neg_samples_amount = 10  # a multiple of the context window

    def _context_positions(self, index):
        """Return the corpus positions of the context words around ``index``."""
        n_words = len(self.words)
        span = 2 * self.window
        # Shift the window so it stays inside the corpus instead of shrinking.
        start = max(0, min(index - self.window, n_words - span - 1))
        return [pos for pos in range(start, start + span + 1)
                if pos != index and pos < n_words]

    def __getitem__(self, index):
        """Return the training example centred on corpus position ``index``.

        Returns:
            ``(context, target)`` in CBOW mode, or
            ``(center, context, negative_samples)`` in skip-gram mode; all
            values are ``torch.long`` tensors of word ids.

        Raises:
            IndexError: if ``index`` lies outside the corpus.
        """
        n_words = len(self.words)
        if index < 0:
            index += n_words
        if not 0 <= index < n_words:
            raise IndexError(f'index {index} is out of range for a corpus of {n_words} words')

        context = torch.tensor(
            [self.word_to_ix[self.words[pos]] for pos in self._context_positions(index)],
            dtype=torch.long,
        )
        target = torch.tensor(self.word_to_ix[self.words[index]], dtype=torch.long)
        if self.mode == 'cbow':
            return context, target

        # Negative samples are drawn uniformly over the vocabulary; sample with
        # replacement only when the vocabulary is too small to do otherwise.
        with_replacement = self.neg_samples_amount > self.n_namples
        context_neg = np.random.choice(
            self.n_namples, self.neg_samples_amount, replace=with_replacement
        )
        return target, context, torch.as_tensor(context_neg, dtype=torch.long)

    def __len__(self):
        """Return the vocabulary size (see the class-level note)."""
        return self.n_namples
    
# Smoke test: pull a single mini-batch of two examples and display it.
dataset = TXTDataset()
dataloader = DataLoader(dataset=dataset, batch_size=2, shuffle=True)
dataiter = iter(dataloader)
# Use the built-in next(); the iterator's Python 2-style `.next()` method is
# not available on current PyTorch DataLoader iterators.
data_1 = next(dataiter)
data_1

[tensor([[21150, 22911, 21145,  1879],
         [25601, 20455, 17335, 18605]]),
 tensor([28782,  4746])]

# CBOW model (without negative sampling)

In [385]:
import torch.nn as nn

class CBOW(torch.nn.Module):
    """Continuous bag-of-words model with a full softmax output layer.

    The context word embeddings are averaged, passed through a 128-unit ReLU
    hidden layer and projected onto the vocabulary; ``forward`` returns
    log-probabilities suitable for ``nn.NLLLoss``.

    Args:
        vocab_size: number of distinct word ids.
        embedding_dim: size of each word embedding.
    """

    def __init__(self, vocab_size, embedding_dim):
        super().__init__()

        # out: 1 x embedding_dim
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.linear1 = nn.Linear(embedding_dim, 128)
        self.activation_function1 = nn.ReLU()

        # out: 1 x vocab_size
        self.linear2 = nn.Linear(128, vocab_size)
        self.activation_function2 = nn.LogSoftmax(dim=-1)
        self.init_emb(embedding_dim)

    def init_emb(self, embedding_dim):
        """Re-initialise the embedding table uniformly in ``[-0.5/dim, 0.5/dim]``.

        :return: None
        """
        initrange = 0.5 / embedding_dim
        with torch.no_grad():
            self.embeddings.weight.uniform_(-initrange, initrange)

    def forward(self, inputs):
        """Map a batch of context word ids to log-probabilities over the vocabulary.

        Args:
            inputs: integer tensor of word ids; the context words are averaged
                over dimension 1.

        Returns:
            Tensor of log-probabilities whose last dimension is ``vocab_size``.
        """
        embeds = self.embeddings(inputs)
        embeds = torch.mean(embeds, dim=1)
        out = self.linear1(embeds)
        out = self.activation_function1(out)
        out = self.linear2(out)
        out = self.activation_function2(out)
        return out

    def get_word_emdedding(self, word, word_to_ix=None):
        """Return the embedding of ``word`` as a ``1 x embedding_dim`` tensor.

        Args:
            word: key to look up in the word-to-index mapping.
            word_to_ix: mapping from word to id (e.g. ``dataset.word_to_ix``).
                When omitted, a module-level ``word_to_ix`` is used if one
                exists, which is the only lookup the method used to perform.

        Raises:
            NameError: if no mapping is passed and no global ``word_to_ix`` exists.
            KeyError: if ``word`` is not in the mapping.
        """
        if word_to_ix is None:
            word_to_ix = globals().get('word_to_ix')
            if word_to_ix is None:
                raise NameError(
                    "name 'word_to_ix' is not defined; pass the mapping explicitly, "
                    "e.g. model.get_word_emdedding(word, dataset.word_to_ix)"
                )
        # Build the index on the same device as the embedding table so the
        # lookup also works after the model has been moved off the CPU.
        index = torch.tensor([word_to_ix[word]], device=self.embeddings.weight.device)
        return self.embeddings(index).view(1, -1)

# SkipGram model (with negative sampling)

In [373]:
import torch.nn as nn
import torch.nn.functional as F

class SkipGram(nn.Module):
    """Skip-gram model trained with negative sampling.

    Two embedding tables are kept: ``in_emb`` for centre words and ``out_emb``
    for context and negative words.

    Args:
        vocab_size: number of distinct word ids.
        embedding_dim: size of each word embedding.
    """

    def __init__(self, vocab_size, embedding_dim):
        super().__init__()

        self.in_emb = nn.Embedding(vocab_size, embedding_dim)
        self.out_emb = nn.Embedding(vocab_size, embedding_dim)

    def forward(self, center_word, context_words, neg_context):
        '''
        center_word: ids of the centre words, [batch_size]
        context_words: ids of the words inside the context window, [batch_size, n_context]
        neg_context: ids of the negative samples, [batch_size, n_negative]
        return: negative-sampling loss for each example, [batch_size]
        '''
        center_col = self.in_emb(center_word).unsqueeze(2)  # [batch, embed_size, 1]
        context_words_emb = self.out_emb(context_words)  # [batch, n_context, embed_size]
        neg_emb = self.out_emb(neg_context)  # [batch, n_negative, embed_size]

        # Squeeze only the trailing singleton axis. A bare squeeze() would also
        # drop the batch (or context) axis when it has size 1, and the sum over
        # dim 1 below would then fail.
        pos_score = torch.bmm(context_words_emb, center_col).squeeze(2)
        neg_score = torch.bmm(neg_emb, -center_col).squeeze(2)

        log_pos = F.logsigmoid(pos_score).sum(1)
        log_neg = F.logsigmoid(neg_score).sum(1)

        loss = log_pos + log_neg
        return -loss

# Train CBOW

In [387]:
from tqdm import tqdm

EMDEDDING_DIM = 100

# Build the data pipeline first so the model is sized from the dataset it is
# actually trained on, not from a `dataset` left over from an earlier cell.
dataset = TXTDataset()
dataloader = DataLoader(dataset=dataset, batch_size=1024, shuffle=True)

# Size the model from the vocabulary mapping itself rather than from
# len(dataset), which only happens to equal the vocabulary size.
model = CBOW(len(dataset.word_to_ix), EMDEDDING_DIM)
loss_function = nn.NLLLoss()
# NOTE(review): in the recorded run below the summed loss only moved from
# ~332.42 to ~332.29 over ten epochs with this optimiser setting; a larger
# learning rate (or a different optimiser) is probably worth trying.
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

model.train()
for epoch in range(50):
    # Sum of the per-batch mean losses over one pass of the dataloader.
    trainingloss = 0
    for context, target in tqdm(dataloader):
        optimizer.zero_grad()
        log_probs = model(context)
        loss = loss_function(log_probs, target)
        loss.backward()
        optimizer.step()
        trainingloss += loss.item()
    print(f'Epoch {epoch}, total loss {trainingloss}')

100%|██████████| 32/32 [00:15<00:00,  2.05it/s]


Epoch 0, total loss 332.4214029312134


100%|██████████| 32/32 [00:16<00:00,  1.99it/s]


Epoch 1, total loss 332.4070873260498


100%|██████████| 32/32 [00:15<00:00,  2.08it/s]


Epoch 2, total loss 332.3927688598633


100%|██████████| 32/32 [00:16<00:00,  1.99it/s]


Epoch 3, total loss 332.3777160644531


100%|██████████| 32/32 [00:15<00:00,  2.09it/s]


Epoch 4, total loss 332.3621606826782


100%|██████████| 32/32 [00:15<00:00,  2.01it/s]


Epoch 5, total loss 332.3455619812012


100%|██████████| 32/32 [00:16<00:00,  1.97it/s]


Epoch 6, total loss 332.33208656311035


100%|██████████| 32/32 [00:17<00:00,  1.84it/s]


Epoch 7, total loss 332.31744384765625


100%|██████████| 32/32 [00:16<00:00,  1.96it/s]


Epoch 8, total loss 332.30278301239014


100%|██████████| 32/32 [00:15<00:00,  2.01it/s]


Epoch 9, total loss 332.2868366241455


 16%|█▌        | 5/32 [00:02<00:14,  1.82it/s]


KeyboardInterrupt: 

# Train SkipGram

In [383]:
from tqdm import tqdm

EMDEDDING_DIM = 100

# Build the data pipeline first so the model is sized from the dataset it is
# actually trained on, not from a `dataset` left over from an earlier cell.
dataset = TXTDataset()
vocab_size = len(dataset.word_to_ix)


def skipgram_collate(batch):
    """Turn a list of ``(context, target)`` items into skip-gram training triples.

    The dataset's default items are ``(context, target)`` pairs, while the
    training loop below needs ``(center, context, negatives)``. The target
    word is used as the centre word and ``dataset.neg_samples_amount``
    negative ids per example are drawn uniformly (with replacement) from the
    vocabulary.
    """
    contexts, centers = zip(*batch)
    negatives = torch.randint(
        0, vocab_size, (len(batch), dataset.neg_samples_amount), dtype=torch.long
    )
    return torch.stack(centers), torch.stack(contexts), negatives


dataloader = DataLoader(dataset=dataset, batch_size=1024, shuffle=True,
                        collate_fn=skipgram_collate)

model = SkipGram(vocab_size, EMDEDDING_DIM)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

model.train()
for epoch in range(50):
    # Sum of the per-batch mean losses over one pass of the dataloader.
    trainingloss = 0
    for center, context, context_neg in tqdm(dataloader):
        optimizer.zero_grad()
        loss = model(center, context, context_neg).mean()
        loss.backward()
        optimizer.step()
        trainingloss += loss.item()
    print(f'Epoch {epoch}, total loss {trainingloss}')

100%|██████████| 32/32 [00:27<00:00,  1.17it/s]


Epoch 0, total loss 1817.8382377624512


100%|██████████| 32/32 [00:25<00:00,  1.24it/s]


Epoch 1, total loss 1821.8741874694824


100%|██████████| 32/32 [00:24<00:00,  1.33it/s]


Epoch 2, total loss 1808.5103187561035


100%|██████████| 32/32 [00:24<00:00,  1.33it/s]


Epoch 3, total loss 1799.1431198120117


100%|██████████| 32/32 [00:24<00:00,  1.31it/s]


Epoch 4, total loss 1791.234474182129


100%|██████████| 32/32 [00:24<00:00,  1.30it/s]


Epoch 5, total loss 1783.8088264465332


100%|██████████| 32/32 [00:25<00:00,  1.26it/s]


Epoch 6, total loss 1780.1335906982422


100%|██████████| 32/32 [00:25<00:00,  1.26it/s]


Epoch 7, total loss 1777.1827812194824


100%|██████████| 32/32 [00:25<00:00,  1.23it/s]


Epoch 8, total loss 1763.8979873657227


100%|██████████| 32/32 [00:24<00:00,  1.29it/s]


Epoch 9, total loss 1765.3404273986816


100%|██████████| 32/32 [00:26<00:00,  1.23it/s]


Epoch 10, total loss 1757.2235641479492


 41%|████      | 13/32 [00:11<00:16,  1.17it/s]


KeyboardInterrupt: 