In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [5]:
class Word2VecFast(nn.Module):
    def __init__(self,vocab_size, embed_dim = 100 ) -> None:
        super().__init__()
        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim)
    
    def forward(self, id):
        return self.embed(id)


In [6]:
max_vocab_size = 100000

training_split_ratio = 0.8

In [7]:
import torch
import torch.nn as nn

In [8]:
def load_text8_data(filepath):
    with open(file=filepath) as file:
        text = file.read()

    return text

In [9]:
text = load_text8_data(r"/home/askar/mlx/week1/team/MLX5W1T1/word2vec/data/text8-1mb")

In [10]:
text_words = text.split()

In [11]:
from collections import Counter

word_counts = Counter(text_words)

In [12]:
most_common_words = word_counts.most_common(max_vocab_size)

In [13]:
total_vocab_size = len(most_common_words)

In [14]:
most_common_words[0:10]

[('the', 10802),
 ('of', 6209),
 ('and', 4839),
 ('in', 3980),
 ('to', 3545),
 ('one', 3377),
 ('a', 3204),
 ('zero', 2387),
 ('is', 2052),
 ('nine', 1774)]

In [15]:
word_to_id = {word:id for id, (word, _) in enumerate(most_common_words)}

In [16]:
id_to_word = {id:word for word, id in word_to_id.items()}

In [17]:
traing_test_cutoff = int(len(text_words) * training_split_ratio)
training_words = text_words[:traing_test_cutoff]
test_words = text_words[traing_test_cutoff:]

In [18]:
print(f"""
Total dataset size: {len(text_words)}
Training dataset size: {len(training_words)}
Test dataset size: {len(test_words)}
      """)


Total dataset size: 175599
Training dataset size: 140479
Test dataset size: 35120
      


In [45]:
test_words

['lane',
 'roadway',
 'and',
 'train',
 'tracks',
 'in',
 'the',
 'same',
 'housing',
 'consequently',
 'eastbound',
 'traffic',
 'westbound',
 'traffic',
 'and',
 'the',
 'alaska',
 'railroad',
 'must',
 'share',
 'the',
 'tunnel',
 'resulting',
 'in',
 'waits',
 'of',
 'two',
 'zero',
 'minutes',
 'or',
 'more',
 'to',
 'enter',
 'as',
 'reflected',
 'on',
 'the',
 'alaska',
 'department',
 'of',
 'transportation',
 'tunnel',
 'website',
 'it',
 'is',
 'now',
 'considered',
 'north',
 'america',
 's',
 'longest',
 'railroad',
 'highway',
 'tunnel',
 'the',
 'alaska',
 'railroad',
 'runs',
 'from',
 'seward',
 'through',
 'anchorage',
 'denali',
 'and',
 'fairbanks',
 'to',
 'north',
 'pole',
 'with',
 'spurs',
 'to',
 'whittier',
 'and',
 'palmer',
 'the',
 'railroad',
 'is',
 'famous',
 'for',
 'its',
 'summertime',
 'passenger',
 'services',
 'but',
 'also',
 'plays',
 'a',
 'vital',
 'part',
 'in',
 'moving',
 'alaska',
 's',
 'natural',
 'resources',
 'such',
 'as',
 'coal',
 'an

In [20]:
model = Word2VecFast(total_vocab_size)

### Training

In [22]:
import torch.optim as optim

optimizer = optim.SGD(model.parameters())

In [23]:
loss_fn = nn.BCEWithLogitsLoss()

In [75]:
def pair_generator(words, context_window):
    total_len = len(words)
    for i, word in enumerate(words):
        start = max(0, i - context_window)
        end = min(total_len, i + context_window + 1)
        context_words = words[start:i] + words[i+1:end]
        yield word, context_words

In [76]:
import random
def negative_pair_generator(words, context_window, number_of_samples):
    total_len = len(words)
    for i, word in enumerate(words):
        neg_samples = []
        for i in range(number_of_samples):
            sample_index = i
            while sample_index > i - context_window and sample_index < i + context_window:
                sample_index = random.randint(0, total_len)
            neg_samples.append(words[sample_index])
        yield word, neg_samples

In [116]:
training_pairs = pair_generator(training_words, 2)
neg_training_pairs = negative_pair_generator(training_words, 2, 3)

In [78]:
for _ in range(10):
    print(next(training_pairs), next(neg_training_pairs))

('anarchism', ['originated', 'as']) ('anarchism', ['socialism', 'to', 'h'])
('originated', ['anarchism', 'as', 'a']) ('originated', ['to', 'one', 'require'])
('as', ['anarchism', 'originated', 'a', 'term']) ('as', ['an', 'the', 'section'])
('a', ['originated', 'as', 'term', 'of']) ('a', ['politics', 'its', 'scientists'])
('term', ['as', 'a', 'of', 'abuse']) ('term', ['aid', 'he', 'the'])
('of', ['a', 'term', 'abuse', 'first']) ('of', ['pattern', 'demographics', 'other'])
('abuse', ['term', 'of', 'first', 'used']) ('abuse', ['zero', 'and', 'the'])
('first', ['of', 'abuse', 'used', 'against']) ('first', ['by', 's', 'duties'])
('used', ['abuse', 'first', 'against', 'early']) ('used', ['viet', 'sections', 'party'])
('against', ['first', 'used', 'early', 'working']) ('against', ['distribution', 'some', 'of'])


In [117]:
from tqdm import tqdm

In [138]:
epochs = 4
batch_size = 1000

training_pairs = pair_generator(training_words, 2)
neg_training_pairs = negative_pair_generator(training_words, 2, 3)

loss_over_epochs = []

for i in range(epochs):
    # Run positive samples
    total_loss = 0
    training_pairs = pair_generator(training_words, 2)
    neg_training_pairs = negative_pair_generator(training_words, 2, 3)
    optimizer.zero_grad()

    for i in tqdm(range(batch_size)):
        try:
            word, pairings = next(training_pairs)
        except StopIteration:
            continue
        for pair in pairings:
            target_embed = model(torch.tensor(word_to_id[word]))
            context_embed = model(torch.tensor(word_to_id[pair]))

            pos_score = torch.matmul(target_embed, context_embed.T)
            pos_label = torch.ones(1).squeeze()


            pos_loss = loss_fn(pos_score, pos_label)
            pos_loss.backward()
            total_loss += pos_loss.item()
    optimizer.step()
    loss_over_epochs.append(total_loss)


100%|██████████| 1000/1000 [01:17<00:00, 12.98it/s]
100%|██████████| 1000/1000 [01:12<00:00, 13.73it/s]
100%|██████████| 1000/1000 [01:17<00:00, 12.92it/s]
100%|██████████| 1000/1000 [01:13<00:00, 13.61it/s]


In [140]:
loss_over_epochs

[10469.086188369109, 9796.39432248345, 9171.822217185116, 8591.228345373096]

In [136]:
model(torch.tensor(52))

'many'