<a href="https://colab.research.google.com/github/fannix/nlp_notebook/blob/master/word2vec.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!wget http://mattmahoney.net/dc/text8.zip

--2020-04-15 19:39:34--  http://mattmahoney.net/dc/text8.zip
Resolving mattmahoney.net (mattmahoney.net)... 67.195.197.75
Connecting to mattmahoney.net (mattmahoney.net)|67.195.197.75|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 31344016 (30M) [application/zip]
Saving to: ‘text8.zip’


2020-04-15 19:40:08 (897 KB/s) - ‘text8.zip’ saved [31344016/31344016]



In [3]:
! unzip text8.zip

Archive:  text8.zip
  inflating: text8                   


In [0]:
def any2unicode(text, encoding='utf8', errors='strict'):
  """Convert `text` (bytestring in given encoding or unicode) to unicode.
  """
  return str(text, encoding, errors=errors)

MAX_WORDS_IN_BATCH = 10000

import smart_open

class Text8Corpus(object):
  def __init__(self, fname, max_sentence_length=MAX_WORDS_IN_BATCH):
    self.fname = fname
    self.max_sentence_length = max_sentence_length

  def __iter__(self):
    # the entire corpus is one gigantic line -- there are no sentence marks at all
    # so just split the sequence of toens arbitrarily: 1 sentence = 1000 tokens
    sentence, rest = [], b''
    with smart_open.smart_open(self.fname) as fin:
      while True:
        text = rest + fin.read(8192)
        if text == rest:
          words = any2unicode(text).split()
          sentence.extend(words)
          if sentence:
            yield sentence
          break
        last_token = text.rfind(b' ')
        words, rest = (any2unicode(text[:last_token]).split(),
                       text[last_token:].strip()) if last_token >= 0 else ([], text)
        sentence.extend(words)
        while len(sentence) >= self.max_sentence_length:
          yield sentence[:self.max_sentence_length]
          sentence = sentence[self.max_sentence_length:]

In [5]:
corpus = Text8Corpus("text8")
from collections import Counter
counter = Counter()
for i, e in enumerate(corpus):
  counter.update(e)

  'See the migration notes for details: %s' % _MIGRATION_NOTES_URL


In [6]:
counter.most_common(5)

[('the', 1061396),
 ('of', 593677),
 ('and', 416629),
 ('one', 411764),
 ('in', 372201)]

In [0]:
VOCAB_SIZE = 50000
word2idx = {'UNK': 0}
for word, _ in counter.most_common(VOCAB_SIZE-1):
  word2idx[word] = len(word2idx)

idx2word = dict(zip(word2idx.values(), word2idx.keys()))

In [0]:
import torch
from torch.utils.data import DataLoader, Dataset

In [9]:
class Text8Dataset(Dataset):

  def __init__(self, max_sentence_length=MAX_WORDS_IN_BATCH):
    super().__init__()
    corpus = Text8Corpus("text8", max_sentence_length)
    self.x = []
    self.y = []
    for sentence in corpus:
      for i in range(1, len(sentence) - 1):
        left = word2idx.get(sentence[i - 1], 0)
        word = word2idx.get(sentence[i], 0)
        right = word2idx.get(sentence[i + 1], 0)
        self.x.append(word)
        self.y.append(left)
        self.x.append(word)
        self.y.append(right)
  
  def __len__(self):
    return len(self.x)

  def __getitem__(self, i):
    return self.x[i], self.y[i]

dataset = Text8Dataset()
data_loader = DataLoader(dataset, 128)

for x, y in data_loader:
  print(x, y)
  break

  'See the migration notes for details: %s' % _MIGRATION_NOTES_URL


tensor([ 3081,  3081,    12,    12,     6,     6,   195,   195,     2,     2,
         3134,  3134,    46,    46,    59,    59,   156,   156,   128,   128,
          742,   742,   477,   477, 10572, 10572,   134,   134,     1,     1,
        27350, 27350,     2,     2,     1,     1,   103,   103,   855,   855,
            3,     3,     1,     1, 15068, 15068,     0,     0,     2,     2,
            1,     1,   151,   151,   855,   855,  3581,  3581,     1,     1,
          195,   195,    11,    11,   191,   191,    59,    59,     5,     5,
            6,     6, 10713, 10713,   215,   215,     7,     7,  1325,  1325,
          105,   105,   455,   455,    20,    20,    59,    59,  2732,  2732,
          363,   363,     7,     7,  3673,  3673,     1,     1,   709,   709,
            2,     2,   372,   372,    27,    27,    41,    41,    37,    37,
           54,    54,   540,   540,    98,    98,    12,    12,     6,     6,
         1424,  1424,  2758,  2758,    19,    19,   568,   568])

In [0]:
from torch import nn, optim
cuda = torch.device('cuda')

class SkipGram(nn.Module):
  def __init__(self, num_word, num_dim):
    super().__init__()
    self.embedding_in = nn.Embedding(num_word, num_dim)
    self.linear = nn.Linear(num_dim, num_word)
  
  def forward(self, X):
    em = self.embedding_in(X)
    linear = self.linear(em)
    return linear
    

In [0]:
skip_gram = SkipGram(VOCAB_SIZE, 128).to(cuda)

sgd = optim.Adam(skip_gram.parameters())
criterion = nn.CrossEntropyLoss()

In [0]:
for epoch in range(2):
  running_loss = 0
  for i, e in enumerate(data_loader):
    input, target = e
    input = input.to(cuda)
    target = target.to(cuda)

    sgd.zero_grad()
    out = skip_gram(input)
    batch_loss = criterion(out, target)
    batch_loss.backward()
    sgd.step()

    running_loss += batch_loss.item()

    if i % 100 == 99:
      print('[%d, %5d] loss: %.3f' % (epoch, i, running_loss / 100))
      running_loss = 0

In [0]:
torch.save(skip_gram.state_dict(), "skip_gram2.pt")

In [15]:
skip_gram.state_dict()

OrderedDict([('embedding_in.weight',
              tensor([[-0.1953,  0.1023,  0.5457,  ..., -0.3030,  0.2190,  0.2364],
                      [-0.1474,  0.0351,  0.3400,  ..., -0.6073, -0.0495,  0.0216],
                      [-0.6118,  1.0723,  0.0366,  ..., -0.2548,  0.0650,  0.0518],
                      ...,
                      [ 0.4745,  0.3977, -0.5545,  ..., -0.6421,  0.7280,  1.2740],
                      [ 0.1979,  0.3360,  0.9031,  ..., -2.4365, -1.1855, -1.1454],
                      [-0.2805, -0.9166,  0.1060,  ...,  0.4408, -0.6335, -1.7914]],
                     device='cuda:0')),
             ('linear.weight',
              tensor([[ 4.0786, -6.3992, -4.4308,  ...,  5.4683, -4.6958, -5.5532],
                      [ 4.1408, -6.3077, -4.5007,  ...,  5.4353, -4.6623, -5.4429],
                      [ 4.1072, -6.5657, -4.3885,  ...,  5.4375, -4.6753, -5.7935],
                      ...,
                      [ 4.2185, -6.2939, -4.7321,  ...,  5.0268, -4.7996, -5.2787

In [0]:
from sklearn.metrics.pairwise import cosine_similarity

In [22]:
weight = skip_gram.state_dict()['embedding_in.weight'].cpu()
# word_in = word2idx['one']
word_in = word2idx['ball']
li = []

for i in range(VOCAB_SIZE):
  if i != word_in:
    similarity = cosine_similarity(
        weight[word_in].reshape(1, -1),
        weight[i].reshape(1, -1))
    li.append((similarity, idx2word[i]))
li.sort(reverse=True)
li[:20]

[(array([[0.47643974]], dtype=float32), 'team'),
 (array([[0.43092996]], dtype=float32), 'football'),
 (array([[0.41306]], dtype=float32), 'player'),
 (array([[0.41072005]], dtype=float32), 'flower'),
 (array([[0.4083807]], dtype=float32), 'bar'),
 (array([[0.39572728]], dtype=float32), 'line'),
 (array([[0.38700283]], dtype=float32), 'champion'),
 (array([[0.38408875]], dtype=float32), 'triple'),
 (array([[0.37671763]], dtype=float32), 'smooth'),
 (array([[0.37598515]], dtype=float32), 'jump'),
 (array([[0.3727573]], dtype=float32), 'label'),
 (array([[0.37117153]], dtype=float32), 'town'),
 (array([[0.3709255]], dtype=float32), 'push'),
 (array([[0.36844778]], dtype=float32), 'disc'),
 (array([[0.36836493]], dtype=float32), 'rod'),
 (array([[0.3680771]], dtype=float32), 'sky'),
 (array([[0.36607373]], dtype=float32), 'yellow'),
 (array([[0.365142]], dtype=float32), 'visitor'),
 (array([[0.3649454]], dtype=float32), 'round'),
 (array([[0.36466426]], dtype=float32), 'green')]