In [53]:
import numpy as np
from collections import deque
import nltk
import re
import random
from nltk.corpus import brown
from nltk.corpus import gutenberg
nltk.download('gutenberg')
nltk.download('brown')
nltk.download('punkt')


class InputData:
    def __init__(self, sentences):
        self.norm_sentences = []
        self.counter = 0
        self.wordId_frequency_dict = dict()
        self.word_count = 0  #  Number of words (repeated words only count as 1)
        self.word_count_sum = 0  # Total number of words (the number of repeated words also accumulates)
        self.sentence_count = 0  # Number of sentences
        self.id2word_dict = dict()
        self.word2id_dict = dict()
        self._init_dict(sentences)  # Initialize the dictionary
        self.subsampling()
        self.sample_table = []
        self._init_sample_table()
        self.word_pairs_queue = deque()

        print('Word Count is:', self.word_count)
        print('Word Count Sum is', self.word_count_sum)
        print('Sentence Count is:', self.sentence_count)

    def special_match(strg, search=re.compile(r'[^a-z0-9.]').search):
      return not bool(search(strg))


    def subsampling(self):
        t = 0.0003
        frequency = np.array(list(self.wordId_frequency_dict.values()))
        z = frequency / sum(frequency)
        p = (np.sqrt(z / t) + 1) * (t / z)

        for index, word_list in enumerate(self.norm_sentences):
          word_list = [word for word in word_list if p[self.word2id_dict[word]] > random.random()]
          self.norm_sentences[index] = word_list

    def normalize(self, word_list):
      sentence = " ".join(word for word in word_list)
      sentence = re.sub(r'[^a-zA-Z\s]', '', sentence)
      sentence = sentence.lower()
      sentence = re.sub(' +', ' ', sentence)
      sentence = sentence.strip()
      norm_word_list = sentence.split(' ')

      return norm_word_list
       

    def _init_dict(self, sentences):
        word_freq = dict()
        for word_list in sentences:
            word_list = self.normalize(word_list)
            if(len(word_list) < 2):
                continue
            self.word_count_sum += len(word_list)
            self.sentence_count += 1
            for word in word_list:
                try:
                    word_freq[word] += 1
                except:
                    word_freq[word] = 1
            self.norm_sentences.append(word_list)

        word_id = 0
        # Initialize word2id_dict, id2word_dict, wordId_frequency_dict dictionary
        for per_word, per_count in word_freq.items():
            self.id2word_dict[word_id] = per_word
            self.word2id_dict[per_word] = word_id
            self.wordId_frequency_dict[word_id] = per_count
            word_id += 1
        self.word_count = len(self.word2id_dict)

    def _init_sample_table(self):
        sample_table_size = 1e8
        frequency = np.array(list(self.wordId_frequency_dict.values())) ** 0.75
        frequency_sum = sum(frequency)  # Total word frequency of all words
        ratio_array = frequency / frequency_sum 
        word_count_list = np.round(ratio_array * sample_table_size)
        for word_index, word_freq in enumerate(word_count_list):
            self.sample_table += [word_index] * int(word_freq)  # it generates a list, the content is the id of each word, each id in the list is repeated multiple times, the number of repetitions is the word frequency
        self.sample_table = np.array(self.sample_table)
        print(self.sample_table.shape)

    def generate_positive_pairs(self, window_size, neg_count):
        self.counter += 1
        if not self.norm_sentences[20*(self.counter-1):20*self.counter]:
            self.counter = 1
            self.word_pairs_queue.clear()
        sub_wids = [[self.word2id_dict[word] for word in word_list] for word_list in self.norm_sentences[20*(self.counter-1):20*self.counter]]

        # Find the positive sampling pair (w,v) and add it to the positive sampling queue
        for words in sub_wids:
          sentence_length = len(words)
          for index, center_word in enumerate(words):
            start = index - window_size
            end = index + window_size + 1

            positive_words = []
            for index_2 in range(start,end):
              if 0 <= index_2 < sentence_length and index_2 != index:
                positive_words.append(words[index_2])
              elif index_2 < 0 or index_2 >= sentence_length:
                positive_words.append(self.word_count)
            
            negative_words = np.random.choice(self.sample_table, size=neg_count).tolist()

            self.word_pairs_queue.append((center_word, positive_words, negative_words))

    def get_batch_pairs(self, batch_size, window_size, neg_count):

        while len(self.word_pairs_queue) < batch_size:
          self.generate_positive_pairs(window_size, neg_count)              
              
        result_pairs = []  # Returns a positive sample pair of mini-batch size
        for _ in range(batch_size):
            result_pairs.append(self.word_pairs_queue.popleft())
        return result_pairs


    def evaluate_pairs_count(self):
        return self.word_count_sum


def test():
    sentences = brown.sents(categories=['news'])
    test_data = InputData(sentences)
    print(" ".join(word for word in sentences[0]))
    print(" ".join(word for word in test_data.norm_sentences[0]))
    print(" ".join(word for word in sentences[1]))
    print(" ".join(word for word in test_data.norm_sentences[1]))
    # test_data.evaluate_pairs_count()
    pos_pairs = test_data.get_batch_pairs(10, 2, 8)
    print('positive:')
    print(pos_pairs)
    pos_word_pairs = []
    for pair in pos_pairs:
        pos_word_pairs.append((test_data.id2word_dict[pair[0]], [test_data.id2word_dict[i] for i in pair[1] if i != test_data.word_count], [test_data.id2word_dict[i] for i in pair[2] if i != test_data.word_count]))
    print(pos_word_pairs)
    print(len(pos_pairs))

    pos_pairs = test_data.get_batch_pairs(10, 2, 8)
    print('positive:')
    print(pos_pairs)
    pos_word_pairs = []
    for pair in pos_pairs:
        pos_word_pairs.append((test_data.id2word_dict[pair[0]], [test_data.id2word_dict[i] for i in pair[1] if i != test_data.word_count], [test_data.id2word_dict[i] for i in pair[2] if i != test_data.word_count]))
    print(pos_word_pairs)
    print(len(pos_pairs))

    # neg_pair = test_data.get_negative_sampling(pos_pairs, 3)
    # print('negative:')
    # print(neg_pair)
    # neg_word_pair = []
    #for pair in neg_pair:
    #    neg_word_pair.append(
    #        (test_data.id2word_dict[pair[0]], test_data.id2word_dict[pair[1]], test_data.id2word_dict[pair[2]]))
    #print(neg_word_pair)


if __name__ == '__main__':
    test()

[nltk_data] Downloading package gutenberg to /root/nltk_data...
[nltk_data]   Package gutenberg is already up-to-date!
[nltk_data] Downloading package brown to /root/nltk_data...
[nltk_data]   Package brown is already up-to-date!
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
(100002346,)
Word Count is: 12125
Word Count Sum is 86971
Sentence Count is: 4573
The Fulton County Grand Jury said Friday an investigation of Atlanta's recent primary election produced `` no evidence '' that any irregularities took place .
fulton county grand jury friday an investigation atlantas recent primary election produced evidence that any irregularities took place
The jury further said in term-end presentments that the City Executive Committee , which had over-all charge of the election , `` deserves the praise and thanks of the City of Atlanta '' for the manner in which the election was conducted .
jury further termend presentments that city

In [36]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics.pairwise import euclidean_distances


class SkipGramModel(nn.Module):
    def __init__(self, emb_size, emb_dimension):
        super(SkipGramModel, self).__init__()
        self.emb_size = emb_size
        self.emb_dimension = emb_dimension
        self.w_embeddings = nn.Embedding(emb_size, emb_dimension, sparse=True)
        self.v_embeddings = nn.Embedding(emb_size + 1, emb_dimension, sparse=True)
        self._init_emb()

    def _init_emb(self):
        initrange = 0.5 / self.emb_dimension
        self.w_embeddings.weight.data.uniform_(-initrange, initrange) # work
        self.v_embeddings.weight.data.uniform_(-0, 0) # work

    def forward(self, pos_w, pos_v, neg_v):
        emb_w = self.w_embeddings(torch.LongTensor(pos_w))
        emb_v = self.v_embeddings(torch.LongTensor(pos_v))
        neg_emb_v = self.v_embeddings(torch.LongTensor(neg_v))

        score = torch.mul(emb_w.unsqueeze(1), emb_v)
        score = torch.sum(score, dim=2).squeeze()
        score = F.logsigmoid(score)
        score = torch.sum(score, dim=1).squeeze()

        neg_score = torch.mul(emb_w.unsqueeze(1), neg_emb_v)
        neg_score = torch.sum(neg_score, dim=2).squeeze()
        neg_score = F.logsigmoid(-1 * neg_score)
        neg_score = torch.sum(neg_score, dim=1).squeeze()

        # L = log sigmoid (Xw.T * θv) + ∑neg(v) [log sigmoid (-Xw.T * θneg(v))]
        final_score = score + neg_score
        loss = -1 * torch.sum(final_score)
        return loss

    def distance_matrix(self, word_count):
        embedding = self.w_embeddings.weight.data.numpy()[:word_count]
        distance_matrix = euclidean_distances(embedding)
        return distance_matrix


def test():
    model = SkipGramModel(100, 10)
    id2word = dict()
    for i in range(100):
        id2word[i] = str(i)
    pos_w = [0, 2]
    pos_v = [[9,10],[10,12]]
    neg_v = [[23, 42, 74, 32], [32, 24, 62, 53]]
    model.forward(pos_w, pos_v, neg_v)


if __name__ == '__main__':
    test()

In [37]:
import torch.optim as optim
from tqdm import tqdm

# hyper parameters
WINDOW_SIZE = 2 
BATCH_SIZE = 1000  # mini-batch
EMB_DIMENSION = 100  # embedding dimension
LR = 0.01 # Learning rate
NEG_COUNT = 12


class Word2Vec:
    def __init__(self, sentences):
        self.data = InputData(sentences)
        self.model = SkipGramModel(self.data.word_count, EMB_DIMENSION)
        self.lr = LR
        self.optimizer = optim.SGD(self.model.parameters(), lr=self.lr)

    def train(self):
        print("SkipGram Training......")
        pairs_count = self.data.evaluate_pairs_count()
        print("pairs_count", pairs_count)
        batch_count = pairs_count / BATCH_SIZE
        print("batch_count", batch_count)
        for epoch in range(1,51):
            mean_loss = 0
            process_bar = tqdm(range(int(batch_count)))
            for i in process_bar:
                pos_pairs = self.data.get_batch_pairs(BATCH_SIZE, WINDOW_SIZE, NEG_COUNT)
                pos_w = [int(pair[0]) for pair in pos_pairs]
                pos_v = [pair[1] for pair in pos_pairs]
                neg_v = [pair[2] for pair in pos_pairs]

                self.optimizer.zero_grad()
                loss = self.model.forward(pos_w, pos_v, neg_v)
                loss.backward()
                self.optimizer.step()
                mean_loss += loss

            print("epoch:",epoch,"loss:",mean_loss/int(batch_count))


    def get_distance_matrix(self):
        distance_matrix = self.model.distance_matrix(self.data.word_count)
        return distance_matrix


# if __name__ == '__main__':
#     w2v = Word2Vec(input_file_name='./data.txt', output_file_name="word_embedding.txt")
#     w2v.train()

In [38]:
sentences = brown.sents(categories=['news','reviews','government','hobbies','romance'])
w2v = Word2Vec(sentences)

(100002426,)
Word Count is: 24758
Word Count Sum is 313120
Sentence Count is: 17525


In [39]:
w2v.train()

SkipGram Training......
pairs_count 313120
batch_count 313.12


100%|██████████| 313/313 [00:24<00:00, 12.67it/s]


epoch: 1 loss: tensor(10559.4941, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:24<00:00, 13.00it/s]


epoch: 2 loss: tensor(9552.1973, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:25<00:00, 12.34it/s]


epoch: 3 loss: tensor(9081.1162, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:23<00:00, 13.19it/s]


epoch: 4 loss: tensor(8703.0986, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:24<00:00, 12.94it/s]


epoch: 5 loss: tensor(8543.0039, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:24<00:00, 13.00it/s]


epoch: 6 loss: tensor(8271.6953, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:24<00:00, 12.69it/s]


epoch: 7 loss: tensor(8199.7383, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:24<00:00, 12.95it/s]


epoch: 8 loss: tensor(7927.1582, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:23<00:00, 13.44it/s]


epoch: 9 loss: tensor(7886.7163, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:23<00:00, 13.27it/s]


epoch: 10 loss: tensor(7592.1196, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:24<00:00, 13.02it/s]


epoch: 11 loss: tensor(7562.4355, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:23<00:00, 13.06it/s]


epoch: 12 loss: tensor(7249.7588, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:24<00:00, 12.86it/s]


epoch: 13 loss: tensor(7222.5439, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:23<00:00, 13.14it/s]


epoch: 14 loss: tensor(6905.2021, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:24<00:00, 12.99it/s]


epoch: 15 loss: tensor(6869.1694, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:24<00:00, 12.66it/s]


epoch: 16 loss: tensor(6571.0273, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:24<00:00, 12.82it/s]


epoch: 17 loss: tensor(6526.8955, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:24<00:00, 13.04it/s]


epoch: 18 loss: tensor(6242.3540, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:24<00:00, 12.80it/s]


epoch: 19 loss: tensor(6200.8833, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:24<00:00, 12.66it/s]


epoch: 20 loss: tensor(5939.1787, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:24<00:00, 12.78it/s]


epoch: 21 loss: tensor(5895.2720, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:24<00:00, 12.87it/s]


epoch: 22 loss: tensor(5658.0757, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:24<00:00, 12.69it/s]


epoch: 23 loss: tensor(5615.2202, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:24<00:00, 12.61it/s]


epoch: 24 loss: tensor(5398.9575, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:23<00:00, 13.54it/s]


epoch: 25 loss: tensor(5364.9268, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:24<00:00, 12.80it/s]


epoch: 26 loss: tensor(5167.7378, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:24<00:00, 13.00it/s]


epoch: 27 loss: tensor(5129.7109, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:24<00:00, 12.84it/s]


epoch: 28 loss: tensor(4964.8901, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:24<00:00, 12.86it/s]


epoch: 29 loss: tensor(4923.8481, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:24<00:00, 12.98it/s]


epoch: 30 loss: tensor(4775.6602, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:24<00:00, 12.91it/s]


epoch: 31 loss: tensor(4736.5415, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:24<00:00, 12.87it/s]


epoch: 32 loss: tensor(4601.7192, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:24<00:00, 12.97it/s]


epoch: 33 loss: tensor(4569.3848, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:24<00:00, 12.94it/s]


epoch: 34 loss: tensor(4456.5142, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:25<00:00, 12.40it/s]


epoch: 35 loss: tensor(4419.6387, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:24<00:00, 12.74it/s]


epoch: 36 loss: tensor(4317.5361, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:23<00:00, 13.22it/s]


epoch: 37 loss: tensor(4282.6782, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:23<00:00, 13.25it/s]


epoch: 38 loss: tensor(4201.4062, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:24<00:00, 12.62it/s]


epoch: 39 loss: tensor(4163.0273, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:24<00:00, 12.83it/s]


epoch: 40 loss: tensor(4097.1724, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:25<00:00, 12.42it/s]


epoch: 41 loss: tensor(4053.0471, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:24<00:00, 12.83it/s]


epoch: 42 loss: tensor(4000.3977, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:24<00:00, 12.76it/s]


epoch: 43 loss: tensor(3958.9170, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:24<00:00, 12.69it/s]


epoch: 44 loss: tensor(3913.2500, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:24<00:00, 12.71it/s]


epoch: 45 loss: tensor(3871.3762, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:24<00:00, 12.52it/s]


epoch: 46 loss: tensor(3841.8474, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:26<00:00, 11.84it/s]


epoch: 47 loss: tensor(3790.5144, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:23<00:00, 13.11it/s]


epoch: 48 loss: tensor(3772.9133, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:23<00:00, 13.16it/s]


epoch: 49 loss: tensor(3729.4109, grad_fn=<DivBackward0>)


100%|██████████| 313/313 [00:23<00:00, 13.15it/s]

epoch: 50 loss: tensor(3709.0132, grad_fn=<DivBackward0>)





In [40]:
distance_matrix = w2v.get_distance_matrix()

In [52]:
similar_words = {search_term: [w2v.data.id2word_dict[idx] for idx in distance_matrix[w2v.data.word2id_dict[search_term]].argsort()[1:10]] 
                   for search_term in ['tablespoon','election','sauce', 'ballot','mettwurst','car','player','university','republican']}
similar_words

{'ballot': ['referendum',
  'nominating',
  'nonpartisan',
  'subcommittee',
  'gangland',
  'miscount',
  'retirements',
  'welled',
  'buddies'],
 'car': ['collided',
  'streetcar',
  'heading',
  'drives',
  'pezza',
  'renting',
  'morals',
  'headlights',
  'backward'],
 'election': ['induction',
  'commenting',
  'seconddegree',
  'burglary',
  'races',
  'beveling',
  'overexpose',
  'regattas',
  'calmest'],
 'mettwurst': ['bratwurst',
  'bockwurst',
  'niger',
  'stewardesses',
  'knackwurst',
  'harelips',
  'apergillus',
  'flavus',
  'copland'],
 'player': ['birdied',
  'rosburg',
  'reportedly',
  'lewisohn',
  'dookiyoon',
  'invented',
  'tee',
  'chapters',
  'babe'],
 'republican': ['nominee',
  'partys',
  'dirksen',
  'gubernatorial',
  'hillel',
  'leverett',
  'distraught',
  'forum',
  'independents'],
 'sauce': ['pineapple',
  'tablespoon',
  'sweetsour',
  'teaspoon',
  'savory',
  'chunks',
  'worcestershire',
  'sauerkraut',
  'horseradish'],
 'tablespoon': ['