In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl
import datasets
import re
import random

# Seed the global RNGs (via Lightning) before anything random happens, so that
# weight initialisation, shuffling and random.sample below are reproducible.
pl.seed_everything(42)

# Corpus text: the "text" field of the first record of the train split, lower-cased.
shakespeare = (
    datasets.load_dataset('tiny_shakespeare')
    ["train"][0]["text"]
    .lower()
)

In [78]:
# Only the first MAX_SENTENCES "."-delimited chunks are used, keeping the demo small.
MAX_SENTENCES = 1000

# Characters deleted before whitespace tokenisation.
# NOTE(review): deleting "-" and "'" (rather than replacing them with a space) fuses
# the words on either side into a single token; confirm that is intended.
_PUNCTUATION = re.compile(r"[.,:;!?\"'-]")

# Split on "." only (so "!"/"?" do not end a sentence here), then strip punctuation and
# split on whitespace. Slicing *before* cleaning yields exactly the same first
# MAX_SENTENCES token lists without running the regex over the rest of the corpus.
sentences = [
    _PUNCTUATION.sub("", chunk.lower()).split()
    for chunk in shakespeare.split(".")[:MAX_SENTENCES]
]

# Sorted vocabulary gives a deterministic word -> id mapping.
vocab = sorted({word for sentence in sentences for word in sentence})
word_to_idx = {word: idx for idx, word in enumerate(vocab)}

In [79]:
# Number of context words taken on each side of the target (middle) word.
look_around = 2

# Token-id sequences for every sentence long enough to contain at least one full window
# (2 * look_around context words plus the middle word).
train = [
    [word_to_idx[token] for token in sent]
    for sent in sentences
    if len(sent) > (look_around * 2)
]

def windows():
    """Yield ``(context, middle)`` tensor pairs for every full window in ``train``.

    ``context`` holds the ``look_around`` ids before and the ``look_around`` ids after
    the middle position (2 * look_around ids in order); ``middle`` is a 0-d tensor
    holding the id at the centre of the window.
    """
    span = 2 * look_around + 1
    for ids in train:
        for start in range(len(ids) - span + 1):
            centre = start + look_around
            context = ids[start:centre] + ids[centre + 1:start + span]
            yield (torch.tensor(context), torch.tensor(ids[centre]))

In [96]:
class Word2Vec(pl.LightningModule):
    """CBOW-style word2vec: predict the middle word of a window from the mean of its
    context-word embeddings.

    Word <-> id conversion in ``lookup``/``compare``/``predict`` reads the module-level
    ``word_to_idx`` and ``vocab``.
    """

    def __init__(self, vocab_size, embedding_dim = 20):
        """
        Args:
            vocab_size: number of token ids (rows of the embedding table and
                width of the output logits).
            embedding_dim: length of each word vector.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.linear = nn.Linear(embedding_dim, vocab_size)

    def forward(self, x):
        """Return vocab logits for context ids ``x``.

        For ``x`` shaped ``(batch, context_len)`` the result is ``(batch, vocab_size)``.
        """
        x = self.embedding(x)   # (..., context_len, embedding_dim)
        x = x.mean(dim=1)       # average the context embeddings
        x = self.linear(x)
        return x

    def training_step(self, batch, batch_idx):
        """Cross-entropy between predicted middle-word logits and the true middle id."""
        context, middle = batch
        return F.cross_entropy(self(context), middle)

    def configure_optimizers(self):
        """Plain Adam over all parameters."""
        return torch.optim.Adam(self.parameters(), lr=0.001)

    def training_epoch_end(self, outs):
        """Print a similarity probe after each epoch as a cheap progress signal.

        NOTE(review): this hook exists in Lightning 1.x but was removed in 2.0
        (``on_train_epoch_end`` replaces it); confirm the pinned version.
        """
        print(" first ~= second: ", self.compare("first", "second"))

    def _device(self):
        # Device the parameters live on, so index tensors are created alongside them.
        return self.embedding.weight.device

    def lookup(self, word):
        """Return the detached embedding vector of ``word`` (KeyError if out of vocab)."""
        idx = torch.tensor(word_to_idx[word], device=self._device())
        return self.embedding(idx).detach()

    def compare(self, a, b):
        """Cosine similarity of the embeddings of words ``a`` and ``b`` as a 0-d tensor.

        Uses torch's cosine_similarity so it works on any device (numpy's norm cannot
        consume CUDA/MPS tensors) and clamps tiny norms instead of dividing by zero.
        """
        return F.cosine_similarity(self.lookup(a), self.lookup(b), dim=0)

    def predict(self, word):
        """Return the vocab word with the highest logit when ``word`` is the only context word."""
        # Batch of one window containing a single context id: shape (1, 1).
        context = torch.tensor([[word_to_idx[word]]], device=self._device())
        with torch.no_grad():   # inference only; no autograd graph needed
            logits = self(context)
        return vocab[logits.argmax().item()]


model = Word2Vec(len(vocab))
trainer = pl.Trainer(max_epochs=100)

# Materialise every (context, middle) window into a list: shuffling in the
# DataLoader needs an indexable dataset, which a generator is not.
train_loader = torch.utils.data.DataLoader(
    list(windows()),
    batch_size=200,
    shuffle=True,
)
trainer.fit(model, train_loader)


GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name      | Type      | Params
----------------------------------------
0 | embedding | Embedding | 71.6 K
1 | linear    | Linear    | 75.2 K
----------------------------------------
146 K     Trainable params
0         Non-trainable params
146 K     Total params
0.587     Total estimated model params size (MB)


Epoch 0: 100%|██████████| 97/97 [00:00<00:00, 146.61it/s, loss=8.09, v_num=143] first ~= second:  tensor(-0.1279)
Epoch 1: 100%|██████████| 97/97 [00:00<00:00, 131.72it/s, loss=7.87, v_num=143] first ~= second:  tensor(-0.1237)
Epoch 2: 100%|██████████| 97/97 [00:00<00:00, 144.55it/s, loss=7.59, v_num=143] first ~= second:  tensor(-0.1150)
Epoch 3: 100%|██████████| 97/97 [00:00<00:00, 143.43it/s, loss=7.24, v_num=143] first ~= second:  tensor(-0.0958)
Epoch 4: 100%|██████████| 97/97 [00:00<00:00, 134.00it/s, loss=6.9, v_num=143]  first ~= second:  tensor(-0.0671)
Epoch 5: 100%|██████████| 97/97 [00:00<00:00, 136.57it/s, loss=6.68, v_num=143] first ~= second:  tensor(-0.0329)
Epoch 6: 100%|██████████| 97/97 [00:00<00:00, 139.98it/s, loss=6.47, v_num=143] first ~= second:  tensor(0.0009)
Epoch 7: 100%|██████████| 97/97 [00:00<00:00, 117.75it/s, loss=6.46, v_num=143] first ~= second:  tensor(0.0294)
Epoch 8: 100%|██████████| 97/97 [00:00<00:00, 154.91it/s, loss=6.3, v_num=143]  first ~= s

`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 99: 100%|██████████| 97/97 [00:00<00:00, 132.34it/s, loss=3.86, v_num=143]


In [109]:
# For 20 random vocabulary words, show the model's top prediction when that word is
# the only context token.
for sampled in random.sample(vocab, 20):
    predicted = model.predict(sampled)
    print(sampled, predicted)

mischief overtaen
embracements his
troublesome percussion
misery our
used up
kinder woollen
death standst
bred a
broil weeded
sets so
lieutenant i
heavens piercing
leanness the
knew by
methoughti straight
would invincible
surer are
banish you
sign in
selves to


In [112]:
# Cosine similarity between the learned embeddings of selected word pairs.
word_pairs = [
    ("first", "second"),
    ("first", "third"),
    ("second", "third"),
    ("first", "servingman"),
    ("first", "citizen"),
    ("second", "citizen"),
    ("first", "senator"),
    ("lord", "god"),
]
for left, right in word_pairs:
    print(left, "~=", right, model.compare(left, right))

first ~= second tensor(0.4789)
first ~= third tensor(0.2472)
second ~= third tensor(0.4937)
first ~= servingman tensor(0.1155)
first ~= citizen tensor(0.2037)
second ~= citizen tensor(0.4524)
first ~= senator tensor(0.1423)
lord ~= god tensor(0.3285)
