In [46]:
%load_ext autoreload
%autoreload 2
%config IPCompleter.greedy=True
#import pixiedust

import logging
# getLogger() with no name returns the root logger; lowering its level to DEBUG
# means records from library loggers (e.g. the allennlp INFO lines below) are not
# filtered out by level. Whether they are shown still depends on the handlers.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

from collections import Counter

import pickle
import torch
import torch.optim as optim
from allennlp.data.iterators import BasicIterator
from allennlp.data.vocabulary import Vocabulary
from allennlp.models import Model
from allennlp.modules.token_embedders import Embedding
from allennlp.training.trainer import Trainer
from torch.nn import CosineSimilarity
from torch.nn import functional

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [47]:
# Prefer the GPU when PyTorch can see one; everything below is placed on `device`.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device:', device)

Using device: cuda


In [48]:
# Report which GPU is in use and how much memory PyTorch currently holds on it.
if device.type == 'cuda':
    # torch.device('cuda') has index None, meaning "the current device" (0 here).
    gpu_index = device.index if device.index is not None else 0
    # `torch.cuda.memory_cached` was renamed `memory_reserved` in PyTorch 1.4 and
    # is deprecated; use the new name when available, falling back on old versions.
    reserved_bytes = getattr(torch.cuda, 'memory_reserved', None) or torch.cuda.memory_cached
    print(torch.cuda.get_device_name(gpu_index))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(gpu_index)/1024**3,1), 'GB')
    print('Cached:   ', round(reserved_bytes(gpu_index)/1024**3,1), 'GB')

Tesla K80
Memory Usage:
Allocated: 0.0 GB
Cached:    0.1 GB


In [49]:
from word2vec import SkipGramReader

In [50]:
BATCH_SIZE = 256      # training instances per batch handed to the Trainer
EMBEDDING_DIM = 256   # length of each learned word vector

In [51]:
class SkipGramModel(Model):
    """Skip-gram word2vec model trained with a full softmax over the output vocabulary.

    The input word is embedded with ``embedding_in`` and projected by a bias-free
    linear layer to one logit per entry of the ``'token_out'`` namespace. The loss
    is cross-entropy against the observed output word.
    """

    def __init__(self, vocab, embedding_in):
        super().__init__(vocab)
        self.embedding_in = embedding_in
        # Size the projection from the embedder itself so the two layers cannot
        # disagree if the embedder is built with a different width. Fall back to
        # the notebook constant for modules without AllenNLP's ``get_output_dim``.
        get_output_dim = getattr(embedding_in, 'get_output_dim', None)
        in_features = get_output_dim() if callable(get_output_dim) else EMBEDDING_DIM
        self.linear = torch.nn.Linear(
            in_features=in_features,
            out_features=vocab.get_vocab_size('token_out'),
            bias=False)

    def forward(self, token_in, token_out):
        """Return ``{'loss': loss}`` where ``loss`` is the batch-mean cross-entropy.

        NOTE(review): assumes ``token_in`` and ``token_out`` are LongTensors of
        shape (batch,), one id per instance -- TODO confirm against SkipGramReader.
        """
        embedded_in = self.embedding_in(token_in)
        logits = self.linear(embedded_in)
        loss = functional.cross_entropy(logits, token_out)

        return {'loss': loss}

In [52]:
def get_related(token: str, embedding: 'Embedding', vocab: 'Vocabulary', num_synonyms: int = 10):
    """Given a token, return a list of top N most similar words to the token.

    Args:
        token: Query word, looked up in the ``'token_in'`` namespace of ``vocab``.
            NOTE(review): an unknown word presumably maps to AllenNLP's OOV index,
            giving neighbours of ``@@UNKNOWN@@`` -- confirm if that matters.
        embedding: Module exposing a ``weight`` matrix with one row per
            ``'token_in'`` index (e.g. the trained input ``Embedding``).
        vocab: Vocabulary used to map between tokens and row indices.
        num_synonyms: Number of ``(token, similarity)`` pairs to return. It is
            clamped to the vocabulary size, and ``None`` returns every token.

    Returns:
        ``(token, cosine_similarity)`` pairs sorted by decreasing similarity.
        The query token itself is included (similarity ~1.0).
    """
    token_id = vocab.get_token_index(token, 'token_in')
    index_to_token = vocab.get_index_to_token_vocabulary('token_in')
    # Only rows that correspond to a vocabulary entry are candidates.
    candidates = embedding.weight[:len(index_to_token)]

    # Inference only: don't build an autograd graph on the trainable weights.
    with torch.no_grad():
        # One vectorised pass replaces a Python loop with a per-word `.item()`,
        # which forces a device sync for every vocabulary entry on GPU.
        query = candidates[token_id].unsqueeze(0).expand_as(candidates)
        sims = functional.cosine_similarity(candidates, query, dim=1)
        n_candidates = sims.size(0)
        if num_synonyms is None:
            k = n_candidates
        else:
            k = max(0, min(num_synonyms, n_candidates))
        top_sims, top_ids = sims.topk(k)

    return [(index_to_token[idx], sim)
            for idx, sim in zip(top_ids.tolist(), top_sims.tolist())]

In [53]:
RAW_FILE = './inputs/text8'
SAVE_VOCAB = "./outputs/text81e6.vocab"
# Passed to SkipGramReader as `kept_tokens`; presumably caps how much of text8 is
# read (hence "1e6" in SAVE_VOCAB) -- confirm in word2vec.py.
KEPT_TOKENS = 10 ** 6
# AllenNLP only keeps tokens seen at least this many times in each namespace.
MIN_COUNT = 5

reader = SkipGramReader(kept_tokens=KEPT_TOKENS)
text8 = reader.read(RAW_FILE)
vocab = Vocabulary.from_instances(
    text8, min_count={'token_in': MIN_COUNT, 'token_out': MIN_COUNT})

vocab.save_to_files(SAVE_VOCAB)

In [54]:
# Reload the vocabulary saved above; reuse SAVE_VOCAB so the read and write
# paths cannot drift apart.
vocab = Vocabulary.from_files(SAVE_VOCAB)

INFO:allennlp.data.vocabulary:Loading token dictionary from ./outputs/text81e6.vocab.


In [55]:
# print_statistics needs the instance counts, which a vocabulary loaded with
# Vocabulary.from_files does not have (it then only logs that statistics cannot
# be printed). Always report the namespace sizes as well.
vocab.print_statistics()
for namespace in ('token_in', 'token_out'):
    print(namespace, 'vocabulary size:', vocab.get_vocab_size(namespace))

INFO:allennlp.data.vocabulary:Vocabulary statistics cannot be printed since dataset instances were not used for its construction.


In [56]:
# Batches instances (BATCH_SIZE per batch) for the Trainer. index_with attaches
# `vocab` so the iterator can index token strings into ids when building tensors.
iterator = BasicIterator(batch_size=BATCH_SIZE)
iterator.index_with(vocab)

In [57]:
# Number of passes over the training data.
NUM_EPOCHS = 5
# Directory where the Trainer writes its checkpoints (best.th is reloaded from here).
SERIALIZATION_DIR = './outputs/model_save'

# AllenNLP's Trainer takes a GPU index, or -1 for CPU. Derive it from `device`
# rather than hard-coding 0, which AllenNLP rejects when no GPU is available.
if device.type == 'cuda':
    cuda_device = device.index if device.index is not None else 0
else:
    cuda_device = -1

embedding_in = Embedding(num_embeddings=vocab.get_vocab_size('token_in'),
                         embedding_dim=EMBEDDING_DIM)

model = SkipGramModel(vocab=vocab,
                      embedding_in=embedding_in)
model.to(device)
optimizer = optim.Adam(model.parameters())
trainer = Trainer(model=model,
                  optimizer=optimizer,
                  iterator=iterator,
                  train_dataset=text8,
                  num_epochs=NUM_EPOCHS,
                  serialization_dir=SERIALIZATION_DIR,
                  cuda_device=cuda_device)
trainer.train()

In [58]:
# Rebuild the model around a fresh input embedding and load the best checkpoint
# into it. This keeps the cell independent of (and from mutating) the modules
# created by the training cell. `embedding_in` is rebound so the similarity
# queries below use the best weights.
embedding_in = Embedding(num_embeddings=vocab.get_vocab_size('token_in'),
                         embedding_dim=EMBEDDING_DIM)
model = SkipGramModel(vocab=vocab,
                      embedding_in=embedding_in)
# Put every submodule on the same device before loading; load_state_dict copies
# the checkpoint values into these parameters wherever they live.
model.to(device)
model.eval()

# map_location='cpu' lets a checkpoint written on a GPU load on a CPU-only machine.
with open("./outputs/model_save/best.th", 'rb') as f:
    model.load_state_dict(torch.load(f, map_location='cpu'))

In [59]:
# Nearest neighbours of 'one' among the learned input vectors.
print(get_related(token='one', embedding=embedding_in, vocab=vocab))

[('one', 1.0), ('nine', 0.8553415536880493), ('eight', 0.7905139327049255), ('six', 0.7360547780990601), ('seven', 0.7323039770126343), ('five', 0.7249915599822998), ('four', 0.7088095545768738), ('three', 0.6642517447471619), ('two', 0.6638907790184021), ('zero', 0.5777777433395386)]


In [60]:
# Nearest neighbours of 'december' among the learned input vectors.
print(get_related(token='december', embedding=embedding_in, vocab=vocab))

[('december', 1.0), ('january', 0.5533888339996338), ('climatolgy', 0.4994937777519226), ('november', 0.4932716488838196), ('june', 0.49156075716018677), ('october', 0.4878476560115814), ('march', 0.4642040431499481), ('september', 0.45833346247673035), ('april', 0.4539671838283539), ('yorktown', 0.4322149455547333)]
