In [47]:
import numpy as np
from collections import deque
import nltk
import re
import random
from nltk.corpus import brown
from nltk.corpus import gutenberg
nltk.download('gutenberg')
nltk.download('brown')
nltk.download('punkt')
nltk.download('stopwords')


class InputData:
    def __init__(self, sentences, sample):
        self.norm_sentences = []
        self.counter = 0
        self.sample = sample
        self.wordId_frequency_dict = dict()
        self.word_count = 0  #  Number of words (repeated words only count as 1)
        self.word_count_sum = 0  # Total number of words (the number of repeated words also accumulates)
        self.sentence_count = 0  # Number of sentences
        self.id2word_dict = dict()
        self.word2id_dict = dict()
        self._init_dict(sentences)  # Initialize the dictionary
        self.subsampling()
        self.sample_table = []
        self._init_sample_table()
        self.word_pairs_queue = deque()

        print('Word Count is:', self.word_count)
        print('Word Count Sum is', self.word_count_sum)
        print('Sentence Count is:', self.sentence_count)

    def special_match(strg, search=re.compile(r'[^a-z0-9.]').search):
      return not bool(search(strg))


    def subsampling(self):
        
        if self.sample > 0:
            self.word_count_sum = 0
            self.sentence_count = 0

            frequency = np.array(list(self.wordId_frequency_dict.values()))
            z = frequency / np.sum(frequency)
            p = (np.sqrt(z / self.sample) + 1) * (self.sample / z)

            new_norm_sentences = []
            for word_list in self.norm_sentences:
              word_list = [word for word in word_list if p[self.word2id_dict[word]] > random.random()]
              if len(word_list) >= 2:
                self.sentence_count += 1
                self.word_count_sum += len(word_list)
                new_norm_sentences.append(word_list)

            self.norm_sentences = new_norm_sentences

    def normalize(self, word_list):
      sentence = " ".join(word for word in word_list)
      sentence = re.sub(r'[^a-zA-Z\s]', '', sentence)
      sentence = sentence.lower()
      sentence = re.sub(' +', ' ', sentence)
      sentence = sentence.strip()
      norm_word_list = sentence.split(' ')
      if self.sample <= 0:
          stop_words = nltk.corpus.stopwords.words('english')
          norm_word_list_with_out_stop_words = [word for word in norm_word_list if word not in stop_words]
          norm_word_list = norm_word_list_with_out_stop_words

      return norm_word_list
      

    def _init_dict(self, sentences):
        word_freq = dict()
        self.word_count_sum = 0
        for word_list in sentences:
            word_list = self.normalize(word_list)
            if(len(word_list) < 2):
                continue
            self.word_count_sum += len(word_list)
            self.sentence_count += 1
            for word in word_list:
                try:
                    word_freq[word] += 1
                except:
                    word_freq[word] = 1
            self.norm_sentences.append(word_list)

        word_id = 0
        # Initialize word2id_dict, id2word_dict, wordId_frequency_dict dictionary
        for per_word, per_count in word_freq.items():
            self.id2word_dict[word_id] = per_word
            self.word2id_dict[per_word] = word_id
            self.wordId_frequency_dict[word_id] = per_count
            word_id += 1
        self.word_count = len(self.word2id_dict)

    def _init_sample_table(self):
        sample_table_size = 1e8
        frequency = np.array(list(self.wordId_frequency_dict.values())) ** 0.75
        frequency_sum = np.sum(frequency)  # Total word frequency of all words
        ratio_array = frequency / frequency_sum 
        word_count_list = np.round(ratio_array * sample_table_size)
        for word_index, word_freq in enumerate(word_count_list):
            self.sample_table += [word_index] * int(word_freq)  # it generates a list, the content is the id of each word, each id in the list is repeated multiple times, the number of repetitions is the word frequency
        self.sample_table = np.array(self.sample_table)
        print(self.sample_table.shape)

    def generate_positive_pairs(self, window_size, neg_count):
        self.counter += 1
        if not self.norm_sentences[20*(self.counter-1):20*self.counter]:
            self.counter = 1
            self.word_pairs_queue.clear()
        sub_wids = [[self.word2id_dict[word] for word in word_list] for word_list in self.norm_sentences[20*(self.counter-1):20*self.counter]]

        # Find the positive sampling pair (w,v) and add it to the positive sampling queue
        for words in sub_wids:
          sentence_length = len(words)
          for index, center_word in enumerate(words):
            start = index - window_size
            end = index + window_size + 1

            positive_words = []
            for index_2 in range(start,end):
              if 0 <= index_2 < sentence_length and index_2 != index:
                positive_words.append(words[index_2])
              elif index_2 < 0 or index_2 >= sentence_length:
                positive_words.append(self.word_count)
            
            negative_words = np.random.choice(self.sample_table, size=neg_count).tolist()

            self.word_pairs_queue.append((center_word, positive_words, negative_words))

    def get_batch_pairs(self, batch_size, window_size, neg_count):

        while len(self.word_pairs_queue) < batch_size:
          self.generate_positive_pairs(window_size, neg_count)              
              
        result_pairs = []  # Returns a positive sample pair of mini-batch size
        for _ in range(batch_size):
            result_pairs.append(self.word_pairs_queue.popleft())
        return result_pairs


    def evaluate_pairs_count(self):
        return self.word_count_sum


def test():
    sentences = brown.sents(categories=['news'])
    test_data = InputData(sentences,0.0001)
    print(" ".join(word for word in sentences[0]))
    print(" ".join(word for word in test_data.norm_sentences[0]))
    print(" ".join(word for word in sentences[1]))
    print(" ".join(word for word in test_data.norm_sentences[1]))
    # test_data.evaluate_pairs_count()
    pos_pairs = test_data.get_batch_pairs(10, 2, 8)
    print('positive:')
    print(pos_pairs)
    pos_word_pairs = []
    for pair in pos_pairs:
        pos_word_pairs.append((test_data.id2word_dict[pair[0]], [test_data.id2word_dict[i] for i in pair[1] if i != test_data.word_count], [test_data.id2word_dict[i] for i in pair[2] if i != test_data.word_count]))
    print(pos_word_pairs)
    print(len(pos_pairs))

    pos_pairs = test_data.get_batch_pairs(10, 2, 8)
    print('positive:')
    print(pos_pairs)
    pos_word_pairs = []
    for pair in pos_pairs:
        pos_word_pairs.append((test_data.id2word_dict[pair[0]], [test_data.id2word_dict[i] for i in pair[1] if i != test_data.word_count], [test_data.id2word_dict[i] for i in pair[2] if i != test_data.word_count]))
    print(pos_word_pairs)
    print(len(pos_pairs))

    # neg_pair = test_data.get_negative_sampling(pos_pairs, 3)
    # print('negative:')
    # print(neg_pair)
    # neg_word_pair = []
    #for pair in neg_pair:
    #    neg_word_pair.append(
    #        (test_data.id2word_dict[pair[0]], test_data.id2word_dict[pair[1]], test_data.id2word_dict[pair[2]]))
    #print(neg_word_pair)


if __name__ == '__main__':
    test()

[nltk_data] Downloading package gutenberg to /root/nltk_data...
[nltk_data]   Package gutenberg is already up-to-date!
[nltk_data] Downloading package brown to /root/nltk_data...
[nltk_data]   Package brown is already up-to-date!
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
(100002346,)
Word Count is: 12125
Word Count Sum is 49420
Sentence Count is: 4505
The Fulton County Grand Jury said Friday an investigation of Atlanta's recent primary election produced `` no evidence '' that any irregularities took place .
fulton county grand jury said an investigation atlantas recent primary election produced evidence that irregularities took place
The jury further said in term-end presentments that the City Executive Committee , which had over-all charge of the election , `` deserves the praise and thanks of the Ci

In [48]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics.pairwise import euclidean_distances


class SkipGramModel(nn.Module):
    def __init__(self, emb_size, emb_dimension):
        super(SkipGramModel, self).__init__()
        self.emb_size = emb_size
        self.emb_dimension = emb_dimension
        self.w_embeddings = nn.Embedding(emb_size, emb_dimension, sparse=True)
        self.v_embeddings = nn.Embedding(emb_size + 1, emb_dimension, sparse=True)
        self._init_emb()

    def _init_emb(self):
        initrange = 0.5 / self.emb_dimension
        self.w_embeddings.weight.data.uniform_(-initrange, initrange) # work
        self.v_embeddings.weight.data.uniform_(-0, 0) # work

    def forward(self, pos_w, pos_v, neg_v):
        emb_w = self.w_embeddings(torch.LongTensor(pos_w))
        emb_v = self.v_embeddings(torch.LongTensor(pos_v))
        neg_emb_v = self.v_embeddings(torch.LongTensor(neg_v))

        score = torch.mul(emb_w.unsqueeze(1), emb_v)
        score = torch.sum(score, dim=2).squeeze()
        score = F.logsigmoid(score)
        score = torch.sum(score, dim=1).squeeze()

        neg_score = torch.mul(emb_w.unsqueeze(1), neg_emb_v)
        neg_score = torch.sum(neg_score, dim=2).squeeze()
        neg_score = F.logsigmoid(-1 * neg_score)
        neg_score = torch.sum(neg_score, dim=1).squeeze()

        # L = log sigmoid (Xw.T * θv) + ∑neg(v) [log sigmoid (-Xw.T * θneg(v))]
        final_score = score + neg_score
        loss = -1 * torch.sum(final_score)
        return loss

    def distance_matrix(self, word_count):
        embedding = self.w_embeddings.weight.data.numpy()[:word_count]
        distance_matrix = euclidean_distances(embedding)
        return distance_matrix


def test():
    model = SkipGramModel(100, 10)
    id2word = dict()
    for i in range(100):
        id2word[i] = str(i)
    pos_w = [0, 2]
    pos_v = [[9,10],[10,12]]
    neg_v = [[23, 42, 74, 32], [32, 24, 62, 53]]
    model.forward(pos_w, pos_v, neg_v)


if __name__ == '__main__':
    test()

In [49]:
import torch.optim as optim
from tqdm import tqdm

# hyper parameters
WINDOW_SIZE = 4 
BATCH_SIZE = 1000  # mini-batch
EMB_DIMENSION = 100  # embedding dimension
LR = 0.01 # Learning rate
NEG_COUNT = 12


class Word2Vec:
    def __init__(self, sentences, sample):
        self.data = InputData(sentences, sample)
        self.model = SkipGramModel(self.data.word_count, EMB_DIMENSION)
        self.lr = LR
        self.optimizer = optim.SGD(self.model.parameters(), lr=self.lr)

    def train(self):
        print("SkipGram Training......")
        pairs_count = self.data.evaluate_pairs_count()
        print("pairs_count", pairs_count)
        batch_count = pairs_count / BATCH_SIZE
        print("batch_count", batch_count)
        for epoch in range(1,51):
            mean_loss = 0
            process_bar = tqdm(range(int(batch_count)))
            for i in process_bar:
                pos_pairs = self.data.get_batch_pairs(BATCH_SIZE, WINDOW_SIZE, NEG_COUNT)
                pos_w = [int(pair[0]) for pair in pos_pairs]
                pos_v = [pair[1] for pair in pos_pairs]
                neg_v = [pair[2] for pair in pos_pairs]

                self.optimizer.zero_grad()
                loss = self.model.forward(pos_w, pos_v, neg_v)
                loss.backward()
                self.optimizer.step()
                mean_loss += loss

            print("epoch:",epoch,"loss:",mean_loss/int(batch_count))


    def get_distance_matrix(self):
        distance_matrix = self.model.distance_matrix(self.data.word_count)
        return distance_matrix


# if __name__ == '__main__':
#     w2v = Word2Vec(input_file_name='./data.txt', output_file_name="word_embedding.txt")
#     w2v.train()

In [50]:
sentences = brown.sents(categories=['news','reviews','government','hobbies','romance'])
SAMPLE = 0.0002 # use subsampling
w2v = Word2Vec(sentences, SAMPLE)

(100002426,)
Word Count is: 24758
Word Count Sum is 196690
Sentence Count is: 17269


In [51]:
sentences = brown.sents(categories=['news','reviews','government','hobbies','romance'])
SAMPLE = 0 # eliminate stop words
w2v2 = Word2Vec(sentences, SAMPLE)

(99999555,)
Word Count is: 24616
Word Count Sum is 170964
Sentence Count is: 17106


In [52]:
w2v.train()

SkipGram Training......
pairs_count 196690
batch_count 196.69


100%|██████████| 196/196 [00:16<00:00, 11.87it/s]


epoch: 1 loss: tensor(12915.3281, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:16<00:00, 11.89it/s]


epoch: 2 loss: tensor(12333.5928, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:16<00:00, 12.11it/s]


epoch: 3 loss: tensor(12195.7715, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:17<00:00, 11.51it/s]


epoch: 4 loss: tensor(12093.3320, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:16<00:00, 12.15it/s]


epoch: 5 loss: tensor(11996.3818, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:16<00:00, 12.19it/s]


epoch: 6 loss: tensor(11900.4346, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:16<00:00, 11.91it/s]


epoch: 7 loss: tensor(11795.2832, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:16<00:00, 12.19it/s]


epoch: 8 loss: tensor(11690.0762, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:16<00:00, 11.86it/s]


epoch: 9 loss: tensor(11584.6904, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:16<00:00, 11.65it/s]


epoch: 10 loss: tensor(11478.4404, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:15<00:00, 12.38it/s]


epoch: 11 loss: tensor(11365.2578, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:15<00:00, 12.46it/s]


epoch: 12 loss: tensor(11253.7959, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:16<00:00, 12.16it/s]


epoch: 13 loss: tensor(11138.0596, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:15<00:00, 12.64it/s]


epoch: 14 loss: tensor(11017.2705, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:16<00:00, 11.91it/s]


epoch: 15 loss: tensor(10892.2295, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:16<00:00, 11.66it/s]


epoch: 16 loss: tensor(10764.9629, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:15<00:00, 12.30it/s]


epoch: 17 loss: tensor(10628.2715, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:16<00:00, 12.11it/s]


epoch: 18 loss: tensor(10493.9482, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:16<00:00, 11.81it/s]


epoch: 19 loss: tensor(10351.6270, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:19<00:00, 10.27it/s]


epoch: 20 loss: tensor(10211.1025, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:16<00:00, 12.09it/s]


epoch: 21 loss: tensor(10067.9268, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:16<00:00, 11.86it/s]


epoch: 22 loss: tensor(9919.2617, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:16<00:00, 11.89it/s]


epoch: 23 loss: tensor(9769.2764, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:17<00:00, 11.50it/s]


epoch: 24 loss: tensor(9623.7803, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:18<00:00, 10.82it/s]


epoch: 25 loss: tensor(9482.3770, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:16<00:00, 12.05it/s]


epoch: 26 loss: tensor(9330.1553, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:16<00:00, 11.91it/s]


epoch: 27 loss: tensor(9189.7949, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:16<00:00, 12.11it/s]


epoch: 28 loss: tensor(9049.6465, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:16<00:00, 12.10it/s]


epoch: 29 loss: tensor(8911.3760, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:18<00:00, 10.82it/s]


epoch: 30 loss: tensor(8780.0195, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:16<00:00, 11.93it/s]


epoch: 31 loss: tensor(8646.9697, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:16<00:00, 12.14it/s]


epoch: 32 loss: tensor(8521.3115, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:16<00:00, 11.94it/s]


epoch: 33 loss: tensor(8402.7354, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:17<00:00, 11.46it/s]


epoch: 34 loss: tensor(8281.8301, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:16<00:00, 11.94it/s]


epoch: 35 loss: tensor(8170.1294, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:16<00:00, 12.05it/s]


epoch: 36 loss: tensor(8053.9336, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:15<00:00, 12.46it/s]


epoch: 37 loss: tensor(7952.1309, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:15<00:00, 12.50it/s]


epoch: 38 loss: tensor(7857.8984, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:16<00:00, 11.70it/s]


epoch: 39 loss: tensor(7744.8110, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:16<00:00, 12.09it/s]


epoch: 40 loss: tensor(7654.2358, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:16<00:00, 11.63it/s]


epoch: 41 loss: tensor(7561.1694, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:16<00:00, 11.83it/s]


epoch: 42 loss: tensor(7467.4688, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:15<00:00, 12.28it/s]


epoch: 43 loss: tensor(7390.3311, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:15<00:00, 12.56it/s]


epoch: 44 loss: tensor(7303.3486, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:16<00:00, 12.02it/s]


epoch: 45 loss: tensor(7224.8062, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:16<00:00, 11.66it/s]


epoch: 46 loss: tensor(7149.9609, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:15<00:00, 12.25it/s]


epoch: 47 loss: tensor(7080.8311, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:17<00:00, 11.41it/s]


epoch: 48 loss: tensor(7000.9526, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:16<00:00, 12.14it/s]


epoch: 49 loss: tensor(6939.3960, grad_fn=<DivBackward0>)


100%|██████████| 196/196 [00:15<00:00, 12.55it/s]

epoch: 50 loss: tensor(6866.7231, grad_fn=<DivBackward0>)





In [53]:
w2v2.train()

SkipGram Training......
pairs_count 170964
batch_count 170.964


100%|██████████| 170/170 [00:14<00:00, 11.70it/s]


epoch: 1 loss: tensor(12976.0547, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:14<00:00, 11.45it/s]


epoch: 2 loss: tensor(12376.5703, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:14<00:00, 11.43it/s]


epoch: 3 loss: tensor(12164.6338, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:14<00:00, 11.76it/s]


epoch: 4 loss: tensor(12000.3662, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:14<00:00, 12.00it/s]


epoch: 5 loss: tensor(11865.2188, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:15<00:00, 10.85it/s]


epoch: 6 loss: tensor(11750.2432, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:14<00:00, 11.54it/s]


epoch: 7 loss: tensor(11645.7363, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:14<00:00, 11.51it/s]


epoch: 8 loss: tensor(11534.1689, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:15<00:00, 11.23it/s]


epoch: 9 loss: tensor(11415.5732, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:14<00:00, 11.55it/s]


epoch: 10 loss: tensor(11294.0088, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:14<00:00, 11.51it/s]


epoch: 11 loss: tensor(11170.5078, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:14<00:00, 11.60it/s]


epoch: 12 loss: tensor(11048.5518, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:13<00:00, 12.17it/s]


epoch: 13 loss: tensor(10925.1367, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:15<00:00, 11.20it/s]


epoch: 14 loss: tensor(10798.7754, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:15<00:00, 10.89it/s]


epoch: 15 loss: tensor(10672.6113, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:15<00:00, 11.20it/s]


epoch: 16 loss: tensor(10546.7412, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:15<00:00, 11.30it/s]


epoch: 17 loss: tensor(10410.6113, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:14<00:00, 11.49it/s]


epoch: 18 loss: tensor(10280.6338, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:14<00:00, 11.39it/s]


epoch: 19 loss: tensor(10143.3203, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:15<00:00, 10.84it/s]


epoch: 20 loss: tensor(10004.5078, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:14<00:00, 11.52it/s]


epoch: 21 loss: tensor(9863.5117, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:14<00:00, 11.38it/s]


epoch: 22 loss: tensor(9714.2783, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:15<00:00, 11.28it/s]


epoch: 23 loss: tensor(9569.9229, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:15<00:00, 11.21it/s]


epoch: 24 loss: tensor(9420.4600, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:14<00:00, 11.97it/s]


epoch: 25 loss: tensor(9265.7959, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:15<00:00, 11.07it/s]


epoch: 26 loss: tensor(9108.7617, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:15<00:00, 11.08it/s]


epoch: 27 loss: tensor(8949.3564, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:15<00:00, 11.10it/s]


epoch: 28 loss: tensor(8792.6084, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:13<00:00, 12.19it/s]


epoch: 29 loss: tensor(8633.1113, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:15<00:00, 11.33it/s]


epoch: 30 loss: tensor(8479.1016, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:14<00:00, 11.47it/s]


epoch: 31 loss: tensor(8314.6348, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:15<00:00, 11.29it/s]


epoch: 32 loss: tensor(8150.0991, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:16<00:00, 10.29it/s]


epoch: 33 loss: tensor(8006.3550, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:16<00:00, 10.13it/s]


epoch: 34 loss: tensor(7852.0942, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:14<00:00, 11.37it/s]


epoch: 35 loss: tensor(7700.6304, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:15<00:00, 11.10it/s]


epoch: 36 loss: tensor(7554.8027, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:15<00:00, 11.19it/s]


epoch: 37 loss: tensor(7417.6587, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:15<00:00, 10.92it/s]


epoch: 38 loss: tensor(7270.3433, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:14<00:00, 11.75it/s]


epoch: 39 loss: tensor(7136.3037, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:15<00:00, 10.89it/s]


epoch: 40 loss: tensor(7002.5840, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:15<00:00, 10.81it/s]


epoch: 41 loss: tensor(6876.0957, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:15<00:00, 11.02it/s]


epoch: 42 loss: tensor(6755.6274, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:14<00:00, 11.70it/s]


epoch: 43 loss: tensor(6641.8213, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:15<00:00, 11.22it/s]


epoch: 44 loss: tensor(6526.2568, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:14<00:00, 11.37it/s]


epoch: 45 loss: tensor(6412.3862, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:15<00:00, 11.28it/s]


epoch: 46 loss: tensor(6309.8862, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:15<00:00, 11.01it/s]


epoch: 47 loss: tensor(6211.3330, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:14<00:00, 11.74it/s]


epoch: 48 loss: tensor(6108.8188, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:14<00:00, 11.43it/s]


epoch: 49 loss: tensor(6022.3667, grad_fn=<DivBackward0>)


100%|██████████| 170/170 [00:15<00:00, 11.29it/s]

epoch: 50 loss: tensor(5930.6221, grad_fn=<DivBackward0>)





In [54]:
distance_matrix = w2v.get_distance_matrix()

In [55]:
distance_matrix2 = w2v2.get_distance_matrix()

In [56]:
similar_words = {search_term: [w2v.data.id2word_dict[idx] for idx in distance_matrix[w2v.data.word2id_dict[search_term]].argsort()[1:10]] 
                   for search_term in ['tablespoon','election','sauce', 'ballot','mettwurst','car','player','university','republican']}
similar_words

{'ballot': ['disturbance',
  'nominating',
  'referendum',
  'nonpartisan',
  'warrant',
  'cursing',
  'pardon',
  'curiae',
  'septemberoctober'],
 'car': ['gloriana',
  'jenks',
  'collided',
  'mystical',
  'insured',
  'halfdozen',
  'ida',
  'thermostat',
  'commandeering'],
 'election': ['hearings',
  'justices',
  'decertify',
  'coolest',
  'closeness',
  'advisement',
  'indicating',
  'subpenas',
  'mandate'],
 'mettwurst': ['bratwurst',
  'bockwurst',
  'cervelat',
  'haec',
  'stabat',
  'pergolesis',
  'knackwurst',
  'tallchief',
  'flavus'],
 'player': ['birdied',
  'pensacola',
  'strokes',
  'tee',
  'paired',
  'mano',
  'closest',
  'winnings',
  'fourwood'],
 'republican': ['resentment',
  'nominee',
  'sharkey',
  'bookwalter',
  'bronx',
  'ptc',
  'caucus',
  'laughlin',
  'pardon'],
 'sauce': ['tablespoon',
  'teaspoon',
  'mustard',
  'sweetsour',
  'sauerkraut',
  'teaspoons',
  'dough',
  'worcestershire',
  'tablespoons'],
 'tablespoon': ['teaspoons',
  'wo

In [57]:
similar_words = {search_term: [w2v2.data.id2word_dict[idx] for idx in distance_matrix2[w2v2.data.word2id_dict[search_term]].argsort()[1:10]] 
                   for search_term in ['tablespoon','election','sauce', 'ballot','mettwurst','car','player','university','republican']}
similar_words

{'ballot': ['anonymous',
  'referendum',
  'nominating',
  'nonpartisan',
  'pertained',
  'subdue',
  'wrongful',
  'jurytampering',
  'miscount'],
 'car': ['rental',
  'gloriana',
  'vernava',
  'touring',
  'chauffeurdriven',
  'pezza',
  'rear',
  'driving',
  'renting'],
 'election': ['voters',
  'anonymous',
  'coolest',
  'justices',
  'arrests',
  'crossroads',
  'subpenas',
  'republicanism',
  'confession'],
 'mettwurst': ['bratwurst',
  'bockwurst',
  'knackwurst',
  'oleg',
  'bologna',
  'ambrose',
  'bierce',
  'neuritis',
  'pergolesis'],
 'player': ['strokes',
  'augusta',
  'threeround',
  'tee',
  'birdied',
  'mano',
  'winnings',
  'palmers',
  'tournament'],
 'republican': ['nominee',
  'carcass',
  'humphrey',
  'conservatives',
  'dirksen',
  'hubert',
  'mississippis',
  'liberals',
  'sheeran'],
 'sauce': ['tablespoon',
  'mustard',
  'tablespoons',
  'teaspoon',
  'chili',
  'worcestershire',
  'sauerkraut',
  'sweetsour',
  'teaspoons'],
 'tablespoon': ['teas