<a href="https://colab.research.google.com/github/jdasam/aat3020-2023/blob/main/notebooks/1_Word2Vec_training_practice.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import string

In [None]:
!wget "https://raw.githubusercontent.com/amephraim/nlp/master/texts/J.%20K.%20Rowling%20-%20Harry%20Potter%201%20-%20Sorcerer's%20Stone.txt"

--2023-03-14 14:37:18--  https://raw.githubusercontent.com/amephraim/nlp/master/texts/J.%20K.%20Rowling%20-%20Harry%20Potter%201%20-%20Sorcerer's%20Stone.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 439742 (429K) [text/plain]
Saving to: ‘J. K. Rowling - Harry Potter 1 - Sorcerer's Stone.txt.3’


2023-03-14 14:37:18 (8.10 MB/s) - ‘J. K. Rowling - Harry Potter 1 - Sorcerer's Stone.txt.3’ saved [439742/439742]



In [None]:
def remove_punctuation(x):
  """Return the string ``x`` with every character in ``string.punctuation`` deleted."""
  # A translation table whose third argument lists the characters to drop.
  deletion_table = str.maketrans('', '', string.punctuation)
  return x.translate(deletion_table)

def make_tokenized_corpus(corpus):
  """Turn an iterable of sentence strings into a list of lower-cased token lists.

  Each sentence is stripped of punctuation, split on single spaces, and its
  empty pieces are discarded. Sentences that end up with no tokens are left
  out of the result.
  """
  tokenized = []
  for sentence in corpus:
    pieces = remove_punctuation(sentence).split(' ')
    tokens = [piece.lower() for piece in pieces if piece]
    if tokens:
      tokenized.append(tokens)
  return tokenized

In [None]:
# Read the downloaded book line by line.
with open("J. K. Rowling - Harry Potter 1 - Sorcerer's Stone.txt", 'r') as f:
  strings = f.readlines()

# Flatten the book into one long line, rewrite the "Mr."/"Mrs." abbreviations
# so their periods do not act as sentence boundaries, then split on '. '.
flattened_text = "".join(strings).replace('\n', ' ')
without_title_periods = flattened_text.replace('Mr.', 'mr').replace('Mrs.', 'mrs')
sample_text = without_title_periods.split('. ')


In [None]:
# Tokenize every sentence of sample_text, then display the tokens of the
# second sentence as a quick sanity check.
corpus = make_tokenized_corpus(sample_text)
corpus[1]

['they',
 'were',
 'the',
 'last',
 'people',
 'youd',
 'expect',
 'to',
 'be',
 'involved',
 'in',
 'anything',
 'strange',
 'or',
 'mysterious',
 'because',
 'they',
 'just',
 'didnt',
 'hold',
 'with',
 'such',
 'nonsense']

In [None]:
from collections import Counter

def get_entire_words(corpus):
  """Return the sorted list of distinct words found across all sentences of ``corpus``."""
  vocabulary = set()
  for sentence in corpus:
    vocabulary.update(sentence)
  return sorted(vocabulary)

def word_to_idx(unique_word_list):
  """Map each word of ``unique_word_list`` to its position in that list."""
  index_by_word = {}
  for position, word in enumerate(unique_word_list):
    index_by_word[word] = position
  return index_by_word

# Vocabulary before filtering: every distinct token in the corpus.
entire_words = get_entire_words(corpus)
print(f"Num entire unique words: {len(entire_words)}")

# Keep only the words that occur at least min_count times in the corpus.
min_count = 2
word_counter = Counter(token for sentence in corpus for token in sentence)
entire_words = [word for word in entire_words if min_count <= word_counter[word]]
print(f"Num entire unique words after filtering: {len(entire_words)}")

# Index the surviving words by their position in the sorted vocabulary.
word_to_idx_dict = word_to_idx(entire_words)

Num entire unique words: 6038
Num entire unique words after filtering: 3450


In [None]:
def make_word_pair(corpus, window_size=3):
  """Build (word, context_word) pairs for every word of every sentence.

  For the word at position i, the context words are those at positions
  i - window_size .. i + window_size that fall inside the sentence, excluding
  position i itself. One pair is emitted per context word, in sentence order.
  """
  pairs = []
  for tokens in corpus:
    n_tokens = len(tokens)
    for center_pos, center in enumerate(tokens):
      first = max(center_pos - window_size, 0)
      stop = min(center_pos + window_size + 1, n_tokens)
      pairs.extend(
        (center, tokens[context_pos])
        for context_pos in range(first, stop)
        if context_pos != center_pos
      )
  return pairs
# Build the (word, context_word) pairs over the whole corpus using the default window_size of 3.
pair_list = make_word_pair(corpus)

In [None]:
def make_word_pair_for_cbow(corpus, window_size=3):
  """Build (word, context_words) samples for every word of every sentence.

  For the word at position i, ``context_words`` is the list of words at
  positions i - window_size .. i + window_size that fall inside the sentence,
  excluding position i itself. One sample is emitted per word.
  """
  samples = []
  for tokens in corpus:
    n_tokens = len(tokens)
    for center_pos, center in enumerate(tokens):
      first = max(center_pos - window_size, 0)
      stop = min(center_pos + window_size + 1, n_tokens)
      context = [tokens[pos] for pos in range(first, stop) if pos != center_pos]
      samples.append((center, context))
  return samples
# Store the CBOW samples under their own name. Assigning them to pair_list
# would replace the skip-gram (word, context_word) tuples, and the cells below
# look up word_to_idx_dict[pair[1]], which fails with an unhashable-list
# TypeError when pair[1] is a list of context words.
cbow_pair_list = make_word_pair_for_cbow(corpus)

In [None]:
# Display the first 20 entries of pair_list.
pair_list[:20]

[('harry', 'potter'),
 ('harry', 'and'),
 ('harry', 'the'),
 ('potter', 'harry'),
 ('potter', 'and'),
 ('potter', 'the'),
 ('potter', 'sorcerers'),
 ('and', 'harry'),
 ('and', 'potter'),
 ('and', 'the'),
 ('and', 'sorcerers'),
 ('and', 'stone'),
 ('the', 'harry'),
 ('the', 'potter'),
 ('the', 'and'),
 ('the', 'sorcerers'),
 ('the', 'stone'),
 ('the', 'chapter'),
 ('sorcerers', 'potter'),
 ('sorcerers', 'and')]

In [None]:
# Total number of entries in pair_list.
len(pair_list)

409784

In [None]:
# Embedding tables hold one row per vocabulary word and dim_emb values per row.
num_vocab = len(word_to_idx_dict)
dim_emb = 50

# Draw the initial values from a seeded, private generator so that a fresh
# "Restart & Run All" reproduces the same starting embeddings without
# touching the global torch RNG state.
init_rng = torch.Generator().manual_seed(0)

# word_v_mat rows are looked up for center words and word_u_mat rows for
# window (context) words in the cells below.
word_u_mat = torch.randn(num_vocab, dim_emb, generator=init_rng, requires_grad=True)
word_v_mat = torch.randn(num_vocab, dim_emb, generator=init_rng, requires_grad=True)


In [None]:
# Walk through the softmax probability of one (center, window) pair by hand.
pair = pair_list[0]
print(pair)

# Translate both words of the pair into their vocabulary indices.
center_word = word_to_idx_dict[pair[0]]
window_word = word_to_idx_dict[pair[1]]
print(center_word, window_word)

# Center words read from word_v_mat, window words from word_u_mat.
center_vec, window_vec = word_v_mat[center_word], word_u_mat[window_word]

# Score of the observed pair.
dot_product = (center_vec * window_vec).sum()

# Scores of the center word against every row of word_u_mat; their
# exponentials summed give the softmax denominator.
on_entire_vocab = torch.matmul(center_vec, word_u_mat.T)
normalizer = torch.exp(on_entire_vocab).sum(0)
prob = torch.exp(dot_product) / normalizer

('harry', 'potter')
2373 3827


In [None]:
from tqdm import tqdm

# Accumulate the negative log-probability of each (center, window) pair and
# report the mean over the pairs that were actually scored.
total_log_prob = 0
num_scored_pairs = 0
# Wrap the list itself (not enumerate(...)) so tqdm can read its length and
# show real progress instead of an open-ended "0it" counter.
for i, pair in enumerate(tqdm(pair_list)):
  # The vocabulary was filtered by min_count while the pairs were built from
  # the unfiltered corpus, so a pair may contain a word without an index.
  # Skip such pairs instead of raising KeyError.
  if pair[0] not in word_to_idx_dict or pair[1] not in word_to_idx_dict:
    continue
  center_word = word_to_idx_dict[pair[0]]
  window_word = word_to_idx_dict[pair[1]]

  center_vec = word_v_mat[center_word]

  # Scores of the center word against every row of word_u_mat; entry
  # window_word is the dot product with the observed window word.
  on_entire_vocab = torch.matmul(center_vec, word_u_mat.T)
  # log_softmax evaluates log(exp(score) / sum(exp(scores))) in a numerically
  # stable way, avoiding overflow in exp() and the 1e-8 fudge term that a
  # manual exp/sum/log needs.
  log_prob = -torch.log_softmax(on_entire_vocab, dim=0)[window_word]
  total_log_prob += log_prob.item()
  num_scored_pairs += 1
  # NOTE(review): stops after the first scored pair, presumably as a quick
  # smoke test; remove the break to evaluate every pair.
  break
# Average over the pairs that contributed to the sum. Dividing by
# len(pair_list) would understate the value whenever the loop stops early or
# skips pairs.
if num_scored_pairs:
  total_log_prob /= num_scored_pairs

0it [00:00, ?it/s]
