# Skip gram with Negative sampling (SGNS)

Part of **#30DaysOfBasics!**


Contrary to CBOW, Skip-gram takes the word and predict its context within specified window. It is very computation expensive (when vocab size is very big), later, in 2013, Mikolov and the team of researchers provived several extensions of skip gram to improve both the quality of the vectors and the training
speed under the paper titled **'Distributed Representations of Words and Phrases
and their Compositionality** and 'Skip gram with negative sampling' was one of them. 

In [1]:
import random

import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
HI_FILE_PATH = '/Users/impyadav/Desktop/data/data/hi/hi_sample.txt'

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

CONTEXT_WINDOW = 3
EMBED_DIM=128
EPOCHS=500

In [3]:
#Prepare the vocabulary

def create_vocab(text_file):
    with open(text_file, 'r', encoding='utf-8', errors='ignore') as f:
        content = f.read()
    vocab = set(content.split())
    word_to_ix = {word:ix for ix, word in enumerate(vocab)}
    ix_to_word = {ix:word for ix, word in enumerate(vocab)}
    return vocab, word_to_ix, ix_to_word, content.split()


hi_vocab, hi_word_to_ix, hi_ix_to_word, hi_content = create_vocab(HI_FILE_PATH)

print(len(hi_vocab))

222


In [4]:
#generate the negative samples for given idx

def generate_neg_samples(idx, list_of_tokens, context_window, k):
    
    pos_index = range(idx-context_window, idx+context_window+1)
    updated_idxs = set(range(len(list_of_tokens))).difference(set(pos_index))
    
    return random.sample(updated_idxs, k)

In [5]:
#generate samples

def generate_sgns_data(list_of_tokens, context_window, k):
    
    sg_data = []
    
    for idx in range(context_window, len(list_of_tokens)-context_window):
        
        temp = []
        pre_context = [list_of_tokens[idx-idx1-1] for idx1 in range(context_window)]
        post_context = [list_of_tokens[idx+idx1+1] for idx1 in range(context_window)]
        
        for context in pre_context + post_context:
            temp.append([list_of_tokens[idx], context, 1])
            
        temp += [[list_of_tokens[idx], list_of_tokens[idx1], 0] for idx1 in generate_neg_samples(idx, list_of_tokens, context_window, k)]
        
        sg_data.append(temp)
        
    return sg_data

In [6]:
sgns_data = generate_sgns_data(hi_content, 4, 8)

since Python 3.9 and will be removed in a subsequent version.
  return random.sample(updated_idxs, k)


In [7]:
print(sgns_data[:2])

[[['लिए', 'के', 1], ['लिए', 'डिलीवरी', 1], ['लिए', 'में', 1], ['लिए', 'अस्पताल', 1], ['लिए', 'लेबर', 1], ['लिए', 'रूम', 1], ['लिए', 'बना', 1], ['लिए', 'है,', 1], ['लिए', 'कॉपी', 0], ['लिए', 'उत्पन्न', 0], ['लिए', 'बेहतर', 0], ['लिए', 'फीसदी', 0], ['लिए', 'बी.', 0], ['लिए', 'कहा-', 0], ['लिए', 'के', 0], ['लिए', 'फिल्म', 0]], [['लेबर', 'लिए', 1], ['लेबर', 'के', 1], ['लेबर', 'डिलीवरी', 1], ['लेबर', 'में', 1], ['लेबर', 'रूम', 1], ['लेबर', 'बना', 1], ['लेबर', 'है,', 1], ['लेबर', 'लेकिन', 1], ['लेबर', '15:1-2,', 0], ['लेबर', 'करता', 0], ['लेबर', 'इस', 0], ['लेबर', 'ही', 0], ['लेबर', 'रीति', 0], ['लेबर', 'को', 0], ['लेबर', 'यह', 0], ['लेबर', 'हजार', 0]]]


In [8]:
#create the network

class SGNS(nn.Module):
    
    def __init__(self, vocab_size, embed_dim):
        super(SGNS, self).__init__()
        self.input_embedding = nn.Embedding(vocab_size, embed_dim)
        self.context_embedding = nn.Embedding(vocab_size, embed_dim)
        
    
    def forward(self, input_word, context):
        
        input_embedding = self.input_embedding(input_word).view(1,-1)
        context_embedding = self.context_embedding(context)
        
        context_embedding = torch.transpose(context_embedding,0,1)
        
        dot = torch.mm(input_embedding, context_embedding)
        
        scores = F.softmax(dot)
        
        return scores

In [9]:
network = SGNS(len(hi_vocab), EMBED_DIM)

In [10]:
loss_fn = nn.BCELoss()
optimizer = torch.optim.SGD(network.parameters(), lr=0.001)

In [11]:
#Model Training

for epoch in range(EPOCHS):
    
    total_loss = 0
    
    for mix_data in sgns_data:
        
        target_word = mix_data[0][0]
#         print(target_word)
        context = [item[1] for item in mix_data]
#         print(context)
        
        y_label = torch.unsqueeze(torch.tensor([item[2] for item in mix_data]), 1).float()
        
        target_idx = torch.tensor([hi_word_to_ix[target_word]])
        context_idxs = torch.tensor([hi_word_to_ix[word] for word in context], dtype=torch.long)
                                                                                      
        network.zero_grad()
        
        scores = network(target_idx, context_idxs)
        
        loss = loss_fn(torch.transpose(scores,0,1), y_label)
                                  
        loss.backward()
        
        optimizer.step()
                                  
        total_loss += loss.item()
        
    print('Epoch {}/{} and loss: {}'.format(epoch, EPOCHS, total_loss))                       

Epoch 0/500 and loss: 3665.3501102924347
Epoch 1/500 and loss: 3719.2619240283966


  scores = F.softmax(dot)


Epoch 2/500 and loss: 3702.36297249794
Epoch 3/500 and loss: 3685.57830786705
Epoch 4/500 and loss: 3668.915704727173
Epoch 5/500 and loss: 3652.387958049774
Epoch 6/500 and loss: 3635.9967012405396
Epoch 7/500 and loss: 3619.703288078308
Epoch 8/500 and loss: 3603.5487689971924
Epoch 9/500 and loss: 3587.6073784828186
Epoch 10/500 and loss: 3571.8163356781006
Epoch 11/500 and loss: 3556.1648766994476
Epoch 12/500 and loss: 3540.7042322158813
Epoch 13/500 and loss: 3525.4333777427673
Epoch 14/500 and loss: 3510.3021664619446
Epoch 15/500 and loss: 3495.3516652584076
Epoch 16/500 and loss: 3480.5602235794067
Epoch 17/500 and loss: 3465.9268715381622
Epoch 18/500 and loss: 3451.465877056122
Epoch 19/500 and loss: 3437.1852309703827
Epoch 20/500 and loss: 3423.0838873386383
Epoch 21/500 and loss: 3409.1708998680115
Epoch 22/500 and loss: 3395.454136610031
Epoch 23/500 and loss: 3381.9435057640076
Epoch 24/500 and loss: 3368.648289203644
Epoch 25/500 and loss: 3355.564219236374
Epoch 26/50

Epoch 198/500 and loss: 2111.4603332281113
Epoch 199/500 and loss: 2106.96286380291
Epoch 200/500 and loss: 2102.483209133148
Epoch 201/500 and loss: 2098.020444869995
Epoch 202/500 and loss: 2093.5737779140472
Epoch 203/500 and loss: 2089.1435916423798
Epoch 204/500 and loss: 2084.7299571037292
Epoch 205/500 and loss: 2080.332409143448
Epoch 206/500 and loss: 2075.9509679079056
Epoch 207/500 and loss: 2071.5856779813766
Epoch 208/500 and loss: 2067.23602104187
Epoch 209/500 and loss: 2062.901240706444
Epoch 210/500 and loss: 2058.5803487300873
Epoch 211/500 and loss: 2054.273577809334
Epoch 212/500 and loss: 2049.980792284012
Epoch 213/500 and loss: 2045.702199101448
Epoch 214/500 and loss: 2041.4368464946747
Epoch 215/500 and loss: 2037.183542728424
Epoch 216/500 and loss: 2032.9411259889603
Epoch 217/500 and loss: 2028.7092700004578
Epoch 218/500 and loss: 2024.4863681793213
Epoch 219/500 and loss: 2020.2711700201035
Epoch 220/500 and loss: 2016.0675723552704
Epoch 221/500 and loss:

Epoch 390/500 and loss: 1475.8626005649567
Epoch 391/500 and loss: 1473.5918862819672
Epoch 392/500 and loss: 1471.3295992612839
Epoch 393/500 and loss: 1469.0755769014359
Epoch 394/500 and loss: 1466.8297072649002
Epoch 395/500 and loss: 1464.5919170379639
Epoch 396/500 and loss: 1462.3620351552963
Epoch 397/500 and loss: 1460.1400171518326
Epoch 398/500 and loss: 1457.9257514476776
Epoch 399/500 and loss: 1455.7191257476807
Epoch 400/500 and loss: 1453.520028591156
Epoch 401/500 and loss: 1451.3283721208572
Epoch 402/500 and loss: 1449.1439961194992
Epoch 403/500 and loss: 1446.9671739339828
Epoch 404/500 and loss: 1444.7985790967941
Epoch 405/500 and loss: 1442.6381531953812
Epoch 406/500 and loss: 1440.4858955144882
Epoch 407/500 and loss: 1438.3417307138443
Epoch 408/500 and loss: 1436.2055840492249
Epoch 409/500 and loss: 1434.0774857997894
Epoch 410/500 and loss: 1431.9573429822922
Epoch 411/500 and loss: 1429.8452447652817
Epoch 412/500 and loss: 1427.7413140535355
Epoch 413/50

In [12]:
def get_nearest_neighbors(input_word, model, vocab_size, embedding_dim, n_neighbors=5):
    
    #Integer mapping
    target_idx = torch.tensor([hi_word_to_ix[input_word]])
    
    #loading trained embedding
    all_embeds = model.input_embedding.weight.view(1, vocab_size, embedding_dim)
    
    #input_word mapping
    input_embed = model.input_embedding(torch.tensor([hi_word_to_ix[input_word]]))
    
    #cosine similarity
    similarity_fn = nn.CosineSimilarity()
    
    scores = similarity_fn(input_embed, all_embeds)
    
    top_result = torch.topk(scores, n_neighbors+1)

    pred_scores = [item for item in top_result.values.tolist()[0][1:]]
    pred_indices = [item for item in top_result.indices.tolist()[0][1:]]
    
    preds = [(hi_ix_to_word[item], round(item1, 2)) for item, item1 in 
              zip(pred_indices[1:], pred_scores[1:])]
    
    return preds                                                                                                                    

In [13]:
# predictions Generation
print(get_nearest_neighbors('झगड़ा', network, len(hi_vocab), EMBED_DIM))

[('इसकी', 0.13), ('होता', 0.13), ('लिए', 0.12), ('नियुक्ति', 0.12)]
