In [1]:
import numpy as np
import skipgram as sg
from collections import defaultdict

In [2]:
import nltk
def download_nltk_data_if_needed(dataset_name, category='corpora'):
    """Make sure an NLTK resource is available locally, downloading it if absent.

    Args:
        dataset_name: NLTK package identifier, e.g. ``'treebank'``. It is used
            both as the last path component of the lookup and as the argument
            to ``nltk.download``.
        category: Sub-directory of the NLTK data path that holds the resource.
            Defaults to ``'corpora'`` (the value that used to be hard-coded);
            pass e.g. ``'tokenizers'`` or ``'taggers'`` for other resource kinds.

    Returns:
        None. The function is called for its side effect only.
    """
    try:
        # find() raises LookupError when the resource is not on the data path.
        nltk.data.find(f'{category}/{dataset_name}')
    except LookupError:
        nltk.download(dataset_name)
        
# Fetch the NLTK 'treebank' corpus only if it is not already on the local data path.
download_nltk_data_if_needed(dataset_name='treebank')

In [3]:
# Sentences of the NLTK treebank corpus. As the printed sample shows, each
# sentence is a list of (word, POS-tag) tuples.
tagged_corpus = nltk.corpus.treebank.tagged_sents()
print(tagged_corpus[0])

[('Pierre', 'NNP'), ('Vinken', 'NNP'), (',', ','), ('61', 'CD'), ('years', 'NNS'), ('old', 'JJ'), (',', ','), ('will', 'MD'), ('join', 'VB'), ('the', 'DT'), ('board', 'NN'), ('as', 'IN'), ('a', 'DT'), ('nonexecutive', 'JJ'), ('director', 'NN'), ('Nov.', 'NNP'), ('29', 'CD'), ('.', '.')]


In [4]:
# Turn every tagged sentence into a list of "word_TAG" strings (the sample
# printed in the next cell shows the resulting format) ...
tokenized_corpus_with_tags = list(map(sg.tokenize_with_pos, tagged_corpus))
# ... then concatenate all sentences into one flat token stream, in corpus order.
flattened_corpus_with_tags = [
    tagged_token
    for sentence_tokens in tokenized_corpus_with_tags
    for tagged_token in sentence_tokens
]

In [5]:
# Sanity check: show the first ten tokens of the flattened corpus.
print(flattened_corpus_with_tags[:10])

['pierre_NNP', 'vinken_NNP', ',_,', '61_CD', 'years_NNS', 'old_JJ', ',_,', 'will_MD', 'join_VB', 'the_DT']


In [6]:
# Token -> integer id. The default factory hands out the next unused id
# (the current vocabulary size) the first time a token is looked up.
vocab = defaultdict(lambda: len(vocab))
# The corpus re-expressed as ids; this pass is what populates `vocab`.
word_indices = [vocab[token] for token in flattened_corpus_with_tags]
# Freeze the vocabulary: with the factory left in place, any later
# `vocab[unknown_token]` would silently create a new id and grow the mapping.
# After this line such a lookup raises KeyError instead.
vocab.default_factory = None

In [7]:
# Reverse lookup table: integer id -> token string.
index_to_token = dict(zip(vocab.values(), vocab.keys()))

In [8]:
# Training configuration. All six values are passed positionally to sg.train
# in the next cell; the per-parameter meanings below are inferred from the
# names, so confirm them against the skipgram module.
vocab_size = len(vocab)  # number of distinct tokens that were assigned an id
emb_dim = 100  # presumably the length of each embedding vector
window_size = 2  # presumably the context window size
neg_samples = 5  # presumably negative samples per training pair
epochs = 100  # number of passes; the training log prints "Epoch k/100 completed."
learning_rate = 0.01

In [9]:
# Train the embeddings; the call prints one "Epoch k/N completed." line per epoch.
# NOTE(review): the raw token strings are passed here, while `word_indices`
# (the id-encoded corpus built together with `vocab`) is not used in any
# visible cell. Confirm which representation sg.train expects.
word_embeddings = sg.train(
    vocab_size,
    emb_dim,
    flattened_corpus_with_tags,
    window_size,
    neg_samples,
    epochs,
    learning_rate,
)

Epoch 1/100 completed.
Epoch 2/100 completed.
Epoch 3/100 completed.
Epoch 4/100 completed.
Epoch 5/100 completed.
Epoch 6/100 completed.
Epoch 7/100 completed.
Epoch 8/100 completed.
Epoch 9/100 completed.
Epoch 10/100 completed.
Epoch 11/100 completed.
Epoch 12/100 completed.
Epoch 13/100 completed.
Epoch 14/100 completed.
Epoch 15/100 completed.
Epoch 16/100 completed.
Epoch 17/100 completed.
Epoch 18/100 completed.
Epoch 19/100 completed.
Epoch 20/100 completed.
Epoch 21/100 completed.
Epoch 22/100 completed.
Epoch 23/100 completed.
Epoch 24/100 completed.
Epoch 25/100 completed.
Epoch 26/100 completed.
Epoch 27/100 completed.
Epoch 28/100 completed.
Epoch 29/100 completed.
Epoch 30/100 completed.
Epoch 31/100 completed.
Epoch 32/100 completed.
Epoch 33/100 completed.
Epoch 34/100 completed.
Epoch 35/100 completed.
Epoch 36/100 completed.
Epoch 37/100 completed.
Epoch 38/100 completed.
Epoch 39/100 completed.
Epoch 40/100 completed.
Epoch 41/100 completed.
Epoch 42/100 completed.
E

In [10]:
# Display the trained embedding array (the repr elides the middle of large arrays).
word_embeddings

array([[ 0.26138577, -0.49349093, -0.51128775, ...,  0.12386352,
        -0.38667834,  0.19027361],
       [ 0.38441327, -0.39835801, -0.33673531, ...,  0.27986087,
        -0.35998727, -0.05360633],
       [ 0.63743511,  0.71259862,  0.45862903, ...,  0.41736575,
         0.85968269,  0.85482232],
       ...,
       [ 0.07240138,  0.60318292,  0.23129892, ...,  0.01354019,
         0.3034849 ,  0.1182057 ],
       [ 0.03860551, -0.09502112,  0.22884016, ...,  0.12103769,
        -0.07995741, -0.59478295],
       [-0.44216613, -0.21115849, -0.29153576, ..., -0.22529254,
        -0.35627789, -0.39913826]])

In [20]:
test_word_base = 'deal'  # word to look for

# Vocabulary keys have the form "<word>_<POS>", and the token sample printed
# earlier shows the word part lower-cased. Split each key on its LAST
# underscore and compare the word part exactly and case-insensitively.
# A plain prefix test (`startswith(base + '_')`) would also match longer
# words that merely begin with "<base>_", and a query typed with capitals
# would never match.
test_word_key = test_word_base.lower()
test_word_indices = [
    index
    for word, index in vocab.items()
    if word.rpartition('_')[0].lower() == test_word_key
]

if test_word_indices:
    # One vocabulary entry exists per POS tag the word occurs with.
    print(f"Found '{test_word_base}' in vocab with the following indices: {test_word_indices}")
    for test_word_index in test_word_indices:
        # Skip ids that have no corresponding row in the embedding array.
        if test_word_index < word_embeddings.shape[0]:
            # The helper yields (neighbor, similarity) pairs.
            neighbors = sg.find_nearest_neighbors(test_word_index, word_embeddings, index_to_token, top_n=5)
            print(f"Nearest neighbors for index {test_word_index}:")
            for neighbor, similarity in neighbors:
                print(f"  {neighbor}: {similarity}")
        else:
            print(f"Index {test_word_index} out of bounds for embeddings array.")
else:
    print(f"'{test_word_base}' not found in vocabulary.")


Found 'deal' in vocab with the following indices: [3291, 5376, 7248]
Nearest neighbors for index 3291:
  hit_NN: 0.8258289887642533
  seafood_NN: 0.8220139911993182
  fairness_NN: 0.8200012648875471
  len_NN: 0.8141770668603461
  pretext_NN: 0.8104006519552144
Nearest neighbors for index 5376:
  notify_VB: 0.8194625763269121
  disclose_VB: 0.8152288825383291
  step_VB: 0.8104190268102567
  impart_VB: 0.8074770382601183
  accompany_VB: 0.8058699687123654
Nearest neighbors for index 7248:
  prefer_VBP: 0.8094294565770472
  consider_VBP: 0.8042769439060169
  own_VBP: 0.8038776500111564
  differ_VBP: 0.8012875261869283
  evoke_VBP: 0.8001068381253825
