In [None]:
!pip install allennlp==2.5.0
!git clone https://github.com/mhagiwara/realworldnlp.git
%cd realworldnlp

In [None]:
from collections import Counter

import torch
import torch.optim as optim
from allennlp.data.data_loaders import SimpleDataLoader
from allennlp.data.vocabulary import Vocabulary
from allennlp.models import Model
from allennlp.modules.token_embedders import Embedding
from allennlp.training import GradientDescentTrainer
from torch.nn import CosineSimilarity
from torch.nn import functional

In [None]:
from examples.embeddings.word2vec import SkipGramReader

In [None]:
# Dimensionality of each learned word vector.
EMBEDDING_DIM = 256
# Number of training instances per mini-batch.
BATCH_SIZE = 256
# Seed PyTorch's global RNG so that weight initialisation is repeatable after a
# kernel restart; without it every run starts from different random weights.
SEED = 42
torch.manual_seed(SEED);

In [None]:
class SkipGramModel(Model):
    """Skip-gram model: embed an input token, then score every output token.

    The input token is looked up in ``embedding_in`` and projected by a
    bias-free linear layer onto the ``'token_out'`` vocabulary namespace.
    """

    def __init__(self, vocab, embedding_in):
        """Build the model.

        Args:
            vocab: Vocabulary providing the size of the ``'token_out'`` namespace.
            embedding_in: Embedder for input tokens. It must expose
                ``get_output_dim()`` (AllenNLP token embedders do).
        """
        super().__init__(vocab)
        self.embedding_in = embedding_in
        # Take the input width from the embedder itself instead of a notebook
        # global, so the projection always matches the embedder that was passed.
        self.linear = torch.nn.Linear(
            in_features=embedding_in.get_output_dim(),
            out_features=vocab.get_vocab_size('token_out'),
            bias=False)

    def forward(self, token_in, token_out):
        """Compute the cross-entropy loss of predicting ``token_out`` from ``token_in``.

        Args:
            token_in: Indices of the input tokens.
                (assumed to be a 1-D batch of indices - TODO confirm against the reader)
            token_out: Indices of the target tokens, used as cross-entropy targets.

        Returns:
            A dict with a single key ``'loss'`` holding the cross-entropy loss.
        """
        embedded_in = self.embedding_in(token_in)
        # Unnormalised scores over the output vocabulary.
        logits = self.linear(embedded_in)
        loss = functional.cross_entropy(logits, token_out)

        return {'loss': loss}

In [None]:
def get_related(token: str, embedding: torch.nn.Module, vocab: 'Vocabulary', num_synonyms: int = 10):
    """Return the ``num_synonyms`` vocabulary entries most similar to ``token``.

    Similarity is the cosine between the row of ``embedding.weight`` for
    ``token`` and the row for every entry of the ``'token_in'`` namespace.

    Args:
        token: The query word. It is resolved with ``vocab.get_token_index``;
            how unknown words are handled is up to the vocabulary
            (presumably they map to its OOV entry - verify).
        embedding: Module whose ``weight`` holds one vector per vocabulary index.
        vocab: Vocabulary providing the ``'token_in'`` index <-> token mapping.
        num_synonyms: Maximum number of results to return.

    Returns:
        A list of ``(word, similarity)`` tuples sorted by descending similarity.
        The query word itself is included (a vector is most similar to itself).
    """
    index_to_token = vocab.get_index_to_token_vocabulary('token_in')
    token_id = vocab.get_token_index(token, 'token_in')

    # A read-only lookup needs no autograd bookkeeping.
    with torch.no_grad():
        weights = embedding.weight
        # One broadcasted call scores the query row against every row at once,
        # instead of one tensor operation per vocabulary entry.
        cosine = CosineSimilarity(dim=1)
        scores = cosine(weights, weights[token_id].unsqueeze(0)).tolist()

    # Only indices present in the vocabulary namespace are ranked.
    sims = Counter({word: scores[index] for index, word in index_to_token.items()})

    return sims.most_common(num_synonyms)

In [None]:
# Dataset reader imported from examples/embeddings/word2vec.py.
reader = SkipGramReader()
# Read the text8 corpus from the book's public S3 bucket.
text8 = reader.read(
    file_path='https://realworldnlpbook.s3.amazonaws.com/data/text8/text8')

In [None]:
# Materialise the reader's output so it can be counted and sliced.
text8 = [*text8]
print(len(text8))
# Train on the first million instances only.
text8 = text8[:1_000_000]

In [None]:
# Build the vocabulary from the training instances, applying a minimum count
# of 5 to both the input and output token namespaces.
vocab = Vocabulary.from_instances(
    instances=text8,
    min_count=dict.fromkeys(('token_in', 'token_out'), 5),
)

In [None]:
# Batch the instances, then give the loader the vocabulary it indexes with.
data_loader = SimpleDataLoader(instances=text8, batch_size=BATCH_SIZE)
data_loader.index_with(vocab)

In [None]:
# Input-side embedding table: one EMBEDDING_DIM-sized vector per 'token_in' entry.
embedding_in = Embedding(
    embedding_dim=EMBEDDING_DIM,
    num_embeddings=vocab.get_vocab_size('token_in'),
)

In [None]:
# Wrap the input embedding in the skip-gram model defined earlier in the notebook.
model = SkipGramModel(vocab, embedding_in)

In [None]:
# Adam with its default hyper-parameters over every parameter of the model.
optimizer = optim.Adam(params=model.parameters())

In [None]:
# Five passes over the data; cuda_device=-1 is AllenNLP's setting for CPU training.
trainer = GradientDescentTrainer(
    model,
    optimizer,
    data_loader,
    num_epochs=5,
    cuda_device=-1,
)

In [None]:
# Run training with the trainer built in the previous cell; the call's return
# value is this cell's displayed output.
trainer.train()

In [None]:
# Sanity check: nearest neighbours of 'one' in the trained input embedding.
print(get_related(token='one', embedding=embedding_in, vocab=vocab))

In [None]:
# Sanity check: nearest neighbours of 'december' in the trained input embedding.
print(get_related(token='december', embedding=embedding_in, vocab=vocab))