In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random

In [2]:
class Word2Vec(nn.Module): #CBOW
  def __init__(self, vocab_size, emb_dim, window, lr):
    super(Word2Vec, self).__init__()
    self.emb_dim = emb_dim
    self.vocab_size = vocab_size
    self.window = window
    self.W = torch.rand((vocab_size, emb_dim), requires_grad=True)
    self.W_ = torch.rand((emb_dim, vocab_size), requires_grad=True)
    self.optimizer = optim.Adam([self.W, self.W_], lr=lr)
    self.criterion = nn.CrossEntropyLoss()
  
  def forward(self, inputs): 
    #inputs : (batch_size, n_seq)
    batch_cost = 0
    for batch_idx in range(len(inputs)):
      seq = inputs[batch_idx]
      cost = 0
      for i in range(len(seq)):
        left = torch.tensor(inputs[batch_idx][max(0,i-self.window):i], dtype = torch.int64)
        right = torch.tensor(inputs[batch_idx][i+1 : i+self.window+1], dtype = torch.int64)

        context_num = torch.cat((left, right), dim = 0)
        emb = torch.mean(self.W[context_num], dim = 0)

        prob = torch.matmul(emb,self.W_).reshape(1,self.vocab_size)
        target = torch.tensor(inputs[batch_idx][i], dtype = torch.int64).reshape(1)        

        loss = self.criterion(prob, target)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        cost += loss
      batch_cost += cost/len(seq)
    batch_cost /= len(inputs)
    return batch_cost    

# **Dataset**

In [3]:
import nltk
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

In [4]:
import urllib.request
import zipfile
from lxml import etree
import re
from nltk.tokenize import word_tokenize, sent_tokenize
from tqdm import tqdm

In [5]:
urllib.request.urlretrieve("https://raw.githubusercontent.com/GaoleMeng/RNN-and-FFNN-textClassification/master/ted_en-20160408.xml", filename="ted_en-20160408.xml")

('ted_en-20160408.xml', <http.client.HTTPMessage at 0x7fc9d7593c50>)

In [6]:
targetXML=open('ted_en-20160408.xml', 'r', encoding='UTF8')
target_text = etree.parse(targetXML)

parse_text = '\n'.join(target_text.xpath('//content/text()'))

content_text = re.sub(r'\([^)]*\)', '', parse_text)

sent_text = sent_tokenize(content_text)

normalized_text = []
for string in sent_text:
     tokens = re.sub(r"[^a-z0-9]+", " ", string.lower())
     normalized_text.append(tokens)

result = [word_tokenize(sentence) for sentence in normalized_text]
print(len(result))

273424


In [7]:
import random
corpus, indices = [], []
while len(corpus) < 1000:
  idx = random.randint(0, 273423)
  if idx not in indices :
    corpus.append(result[idx])

In [8]:
print(corpus[100])

['and', 'i', 'don', 't', 'want', 'it', 'to', 'leave', 'me', 'behind']


In [9]:
max_len = max(len(s) for s in corpus)
print(max_len)

121


In [10]:
word_to_index = {}
index_to_word = {}
idx = 2
with tqdm(total=len(corpus)) as pbar:
  for i in range(len(corpus)):
    sent = corpus[i]
    for word in sent :
      if word not in word_to_index :
        word_to_index[word] = idx
        index_to_word[idx] = word
        idx += 1
    pbar.update(1)
word_to_index['<pad>'], index_to_word[0] = 0, '<pad>'
word_to_index['<UNK>'], index_to_word[1] = 1, '<UNK>'

100%|██████████| 1000/1000 [00:00<00:00, 90355.54it/s]


In [11]:
vocab_size = len(word_to_index)
emb_dim = 128
window = 4
batch_size = 128
epochs = 10
print(vocab_size)

3480


In [12]:
hist = []
with tqdm(total = max_len) as pbar:
  for th in range(max_len):
    cnt = 0
    for sent in corpus :
      if len(sent) <= th :
        cnt += 1
    hist.append([th, cnt])
    pbar.update(1)

100%|██████████| 121/121 [00:00<00:00, 2773.54it/s]


In [13]:
for pair in hist :
  if pair[1] >= len(corpus)*0.95:
    max_len = pair[0]
    break
print(max_len)

40


In [14]:
input_data = [[word_to_index[sent[i]] if i <= len(sent) - 1 else 0 for i in range(max_len)]for sent in corpus]

In [15]:
w2v = Word2Vec(vocab_size, emb_dim, window, lr=1e-2)

In [16]:
total_batch = len(input_data) // batch_size + 1
for epoch in range(epochs) :
  avg_cost = 0
  print("========= epoch : {} ==========".format(epoch+1))
  for i in range(total_batch):
    inputs = input_data[i*batch_size : min(len(input_data), (i+1)*batch_size)]
    batch_cost = w2v.forward(inputs)
    print("{} batch_cost : {}".format(i+1, batch_cost))
    avg_cost += batch_cost / total_batch
  print("epoch : {} Average cost : {}".format(epoch+1, avg_cost))

1 batch_cost : 3.6747968196868896
2 batch_cost : 3.077592134475708
3 batch_cost : 3.0428335666656494
4 batch_cost : 3.375246524810791
5 batch_cost : 3.099916458129883
6 batch_cost : 3.115767478942871
7 batch_cost : 2.8684921264648438
8 batch_cost : 3.339341640472412
epoch : 1 Average cost : 3.1992485523223877
1 batch_cost : 3.2010600566864014
2 batch_cost : 2.741875410079956
3 batch_cost : 2.6982786655426025
4 batch_cost : 2.9304511547088623
5 batch_cost : 2.7512705326080322
6 batch_cost : 2.7007217407226562
7 batch_cost : 2.4745993614196777
8 batch_cost : 2.8767917156219482
epoch : 2 Average cost : 2.7968811988830566
1 batch_cost : 2.829974889755249
2 batch_cost : 2.420654535293579
3 batch_cost : 2.4285240173339844
4 batch_cost : 2.6164190769195557
5 batch_cost : 2.499236583709717
6 batch_cost : 2.420802116394043
7 batch_cost : 2.253206491470337
8 batch_cost : 2.6052026748657227
epoch : 3 Average cost : 2.5092525482177734
1 batch_cost : 2.5595922470092773
2 batch_cost : 2.222852945327

In [19]:
def get_wordvector(word):
  idx = word_to_index[word]
  return w2v.W[idx]

In [52]:
from numpy.linalg import norm 
import numpy as np
def most_similar(word):
  wordvec = get_wordvector(word).detach().numpy()
  W = w2v.W.detach().numpy()
  distances = []
  similars = []
  for i in range(len(w2v.W)):
    dist = np.sqrt(np.sum((wordvec - W[i])**2))
    distances.append([dist, i])
  
  distances.sort(key = lambda x : x[0])
  for i in range(1, 11):
    similars.append((index_to_word[distances[i][1]], distances[i][0]))
  return similars   

In [53]:
print(most_similar('man'))

[('stainless', 16.878725), ('<pad>', 16.897211), ('seatbelt', 16.900116), ('reefs', 16.953918), ('tested', 16.977743), ('sale', 16.989044), ('heaven', 17.123827), ('silver', 17.171232), ('monitor', 17.203062), ('sleeping', 17.214302)]
