In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import torch.nn.functional as F
import nltk
import random
import numpy as np

In [2]:
def tokenize_corpus(corpus):
    tokens = [x.split() for x in corpus]
    return tokens

In [3]:
train_data = ['Mani is hungry',
              'Do not eat anything',
              'Recommend a restaurant',
              'Good food near restaurant',
              'What will you eat fir dinner ?',
              'I want to eat rice when I am hungry',
              'If you are hungry eat rice',
              'I prefer Wheat to rice for dinner',
              'Wheat is a food',
              'rice is a food',
              'I want to watch movies',
              'All movies are videos but all videos are not movies',
              'Have you seen anything lately?',
              'movies or entertainment recommendation',
              'Show me some funny drama video',
              'Give me some movies of plot with love',
              'Back to the horror comedy movies',
              'Show only funny videos highlights',
              'Ram like funny and comedy movies',
              'Wheat and rice are edible',
              'I like to watch cricket highlights'
             ]

In [4]:
tokenized = tokenize_corpus(train_data)
print tokenized[0]

['Mani', 'is', 'hungry']


In [5]:
from collections import Counter
flatten = lambda l: [item for sublist in l for item in sublist]

In [6]:
word_count = Counter(flatten(tokenized))

In [7]:
#print flatten(tokenized)
print word_count
print list(reversed(word_count.most_common()))[:10]
print list((word_count.most_common()))[:10]

Counter({'movies': 7, 'to': 5, 'rice': 5, 'I': 5, 'are': 4, 'eat': 4, 'videos': 3, 'food': 3, 'is': 3, 'Wheat': 3, 'you': 3, 'funny': 3, 'a': 3, 'hungry': 3, 'and': 2, 'want': 2, 'some': 2, 'highlights': 2, 'me': 2, 'watch': 2, 'not': 2, 'comedy': 2, 'like': 2, 'anything': 2, 'restaurant': 2, 'Show': 2, 'dinner': 2, 'Ram': 1, 'What': 1, 'lately?': 1, 'love': 1, 'recommendation': 1, 'am': 1, 'all': 1, 'Back': 1, 'video': 1, 'Recommend': 1, 'seen': 1, 'fir': 1, 'Do': 1, 'Good': 1, 'for': 1, 'entertainment': 1, 'plot': 1, 'when': 1, 'near': 1, 'only': 1, 'drama': 1, '?': 1, 'cricket': 1, 'Give': 1, 'prefer': 1, 'horror': 1, 'but': 1, 'Mani': 1, 'with': 1, 'edible': 1, 'All': 1, 'of': 1, 'will': 1, 'Have': 1, 'the': 1, 'or': 1, 'If': 1})
[('If', 1), ('or', 1), ('the', 1), ('Have', 1), ('will', 1), ('of', 1), ('All', 1), ('edible', 1), ('with', 1), ('Mani', 1)]
[('movies', 7), ('to', 5), ('rice', 5), ('I', 5), ('are', 4), ('eat', 4), ('videos', 3), ('food', 3), ('is', 3), ('Wheat', 3)]


In [8]:
MIN_COUNT = 2
stopwords = []

In [9]:
# Add the word with min count in stopwords list
for w, c in word_count.items():
    if c < MIN_COUNT:
        if w not in stopwords:
            stopwords.append(w) 
print stopwords
print len(stopwords)

['Ram', 'What', 'lately?', 'love', 'recommendation', 'am', 'all', 'Back', 'video', 'Recommend', 'seen', 'fir', 'Do', 'Good', 'for', 'entertainment', 'plot', 'when', 'near', 'only', 'drama', '?', 'cricket', 'Give', 'prefer', 'horror', 'but', 'Mani', 'with', 'edible', 'All', 'of', 'will', 'Have', 'the', 'or', 'If']
37


In [10]:
vocab = list(set(flatten(tokenized)) - set(stopwords))
print vocab

['and', 'videos', 'is', 'some', 'are', 'want', 'funny', 'Wheat', 'highlights', 'to', 'you', 'comedy', 'a', 'food', 'watch', 'not', 'eat', 'me', 'rice', 'I', 'like', 'anything', 'restaurant', 'Show', 'hungry', 'movies', 'dinner']


In [11]:
word2index = {'<unk>' : 0}
for vo in vocab:
    if word2index.get(vo) is None:
        word2index[vo] = len(word2index)
        
index2word = {v:k for k, v in word2index.items()}
print word2index
print len(word2index)

{'and': 1, 'videos': 2, 'is': 3, 'some': 4, 'are': 5, 'want': 6, 'funny': 7, 'Wheat': 8, 'highlights': 9, 'to': 10, 'you': 11, 'rice': 19, 'a': 13, 'food': 14, 'watch': 15, 'not': 16, 'eat': 17, 'me': 18, 'comedy': 12, 'I': 20, 'like': 21, 'anything': 22, 'restaurant': 23, 'Show': 24, 'hungry': 25, 'movies': 26, 'dinner': 27, '<unk>': 0}
28


In [12]:
WINDOW_SIZE = 5
windows =  flatten([list(nltk.ngrams(['<DUMMY>'] * WINDOW_SIZE + c + ['<DUMMY>'] * WINDOW_SIZE, WINDOW_SIZE * 2 + 1)) for c in tokenized])

#print windows
train_data = []

for window in windows:
    for i in range(WINDOW_SIZE * 2 + 1):
        # stopwords
        if window[i] in stopwords or window[WINDOW_SIZE] in stopwords: 
            continue # min_count
        if i == WINDOW_SIZE or window[i] == '<DUMMY>': 
            continue
        train_data.append((window[WINDOW_SIZE], window[i]))

In [13]:
def prepare_sequence(seq, word2index):
    idxs = list(map(lambda w: word2index[w] if word2index.get(w) is not None else word2index["<unk>"], seq))
    #print idxs
    return torch.LongTensor(idxs)

def prepare_word(word, word2index):
    return torch.LongTensor([word2index[word]]) if word2index.get(word) is not None else torch.LongTensor([word2index["<unk>"]])

In [14]:
#print train_data
X_p = []
y_p = []

for tr in train_data:
    X_p.append(prepare_word(tr[0], word2index))
    y_p.append(prepare_word(tr[1], word2index))
    
train_data = list(zip(X_p, y_p))
print train_data[0]

(
 3
[torch.LongTensor of size 1]
, 
 25
[torch.LongTensor of size 1]
)


In [15]:
import torch.utils.data as torchdata
class WordPair(torchdata.Dataset):
    def __init__(self,dataset):
        self.dataset = dataset
        self.length = len(self.dataset)
        
    def __getitem__(self, index):
        return self.dataset[index]
        
    def __len__(self):
        return self.length

In [16]:
train_data = WordPair(train_data)
train_loader = torchdata.DataLoader(dataset=train_data,
                                           batch_size=16, 
                                           shuffle=True)

In [17]:
word_count = Counter(flatten(tokenized)) # unigram distribution
num_total_words = sum([c for w, c in word_count.items() if w not in stopwords])
print num_total_words

80


In [18]:
#print vocab
unigram_table = []
print word_count
for vo in vocab:
    unigram_table.extend([vo] * int(((word_count[vo]/num_total_words)**0.75)/0.001))
print unigram_table
print(len(vocab), len(unigram_table))

Counter({'movies': 7, 'to': 5, 'rice': 5, 'I': 5, 'are': 4, 'eat': 4, 'videos': 3, 'food': 3, 'is': 3, 'Wheat': 3, 'you': 3, 'funny': 3, 'a': 3, 'hungry': 3, 'and': 2, 'want': 2, 'some': 2, 'highlights': 2, 'me': 2, 'watch': 2, 'not': 2, 'comedy': 2, 'like': 2, 'anything': 2, 'restaurant': 2, 'Show': 2, 'dinner': 2, 'Ram': 1, 'What': 1, 'lately?': 1, 'love': 1, 'recommendation': 1, 'am': 1, 'all': 1, 'Back': 1, 'video': 1, 'Recommend': 1, 'seen': 1, 'fir': 1, 'Do': 1, 'Good': 1, 'for': 1, 'entertainment': 1, 'plot': 1, 'when': 1, 'near': 1, 'only': 1, 'drama': 1, '?': 1, 'cricket': 1, 'Give': 1, 'prefer': 1, 'horror': 1, 'but': 1, 'Mani': 1, 'with': 1, 'edible': 1, 'All': 1, 'of': 1, 'will': 1, 'Have': 1, 'the': 1, 'or': 1, 'If': 1})
[]
(27, 0)


In [19]:
def negative_sampling(targets, unigram_table, k):
    batch_size = targets.size(0)
    neg_samples = []
    for i in range(batch_size):
        nsample = []
        target_index = targets[i].tolist()[0]
        while len(nsample) < k: # num of sampling
            neg = random.choice(unigram_table)
            if word2index[neg] == target_index:
                continue
            nsample.append(neg)
        neg_samples.append(prepare_sequence(nsample, word2index).view(1, -1))
    
    return torch.cat(neg_samples)

In [20]:
class SkipgramNegSampling(nn.Module):
    
    def __init__(self, vocab_size, projection_dim):
        super(SkipgramNegSampling, self).__init__()
        self.embedding_v = nn.Embedding(vocab_size, projection_dim) # center embedding
        self.embedding_u = nn.Embedding(vocab_size, projection_dim) # out embedding
        self.logsigmoid = nn.LogSigmoid()
                
        # xavier init
        self.embedding_v.weight.data = nn.init.xavier_uniform(self.embedding_v.weight.data)
        self.embedding_u.weight.data = nn.init.xavier_uniform(self.embedding_u.weight.data)
        
    def forward(self, center_words, target_words, negative_words):
        center_embeds = self.embedding_v(center_words) # B x 1 x D
        target_embeds = self.embedding_u(target_words) # B x 1 x D
        
        neg_embeds = -self.embedding_u(negative_words) # B x K x D
        
        positive_score = target_embeds.bmm(center_embeds.transpose(1, 2)).squeeze(2) # Bx1
        negative_score = torch.sum(neg_embeds.bmm(center_embeds.transpose(1, 2)).squeeze(2), 1).view(negs.size(0), -1) # BxK -> Bx1
        
        # loss function
        loss = self.logsigmoid(positive_score) + self.logsigmoid(negative_score)
        
        return -torch.mean(loss)
    
    def get_embedding(self, inputs):
        embeds_v = self.embedding_v(inputs)
        embeds_u = self.embedding_u(inputs)
        
        return (embeds_v+embeds_u)/2

In [21]:
EMBEDDING_SIZE = 30 
BATCH_SIZE = 256
EPOCH = 100
NEG = 10 # Num of Negative Sampling
losses = []
model = SkipgramNegSampling(len(word2index), EMBEDDING_SIZE)
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [25]:
for epoch in range(EPOCH):
    for i,(inputs,targets) in enumerate(train_loader):
        
        negs = negative_sampling(targets, unigram_table, NEG)
        inputs = Variable(inputs) # B x 1
        targets = Variable(targets) # B x 1
        negs = Variable(negs) # B x K
        
        model.zero_grad()

        loss = model(inputs, targets, negs)
        
        loss.backward()
        optimizer.step()
    
        losses.append(loss.data.tolist()[0])
    if epoch % 10 == 0:
        print("Epoch : %d, mean_loss : %.02f" % (epoch, np.mean(losses)))
        losses = []

 


In [26]:
def word_similarity(target,index2word,num=10):
    target_V = model.get_embedding(Variable(prepare_word(target, word2index))).view(1,-1)
    matrix = (model.embedding_u.weight.data + model.embedding_v.weight.data)/2
    cosine_sim = F.cosine_similarity(target_V.data, matrix,dim=1,eps=1e-6)
    v,i = cosine_sim.topk(num+1)
    
    return [[index2word[ii],vv] for ii,vv in zip(i.tolist()[1:],v.tolist()[1:])]

In [31]:
word_similarity("rice",index2word)

[['Wheat', 0.3745597302913666],
 ['not', 0.2742573320865631],
 ['anything', 0.24209095537662506],
 ['Show', 0.24021531641483307],
 ['some', 0.20117609202861786],
 ['a', 0.19513079524040222],
 ['<unk>', 0.1612008959054947],
 ['and', 0.1329057216644287],
 ['to', 0.11647970974445343],
 ['watch', 0.11584723740816116]]

In [35]:
word_similarity("movies",index2word)

[['comedy', 0.43324437737464905],
 ['Show', 0.3047579228878021],
 ['funny', 0.26918506622314453],
 ['highlights', 0.22458085417747498],
 ['videos', 0.20939500629901886],
 ['you', 0.20444093644618988],
 ['eat', 0.1973991096019745],
 ['want', 0.16375625133514404],
 ['food', 0.1545613557100296],
 ['like', 0.14094191789627075]]