In [1]:
import os
import torch
from wdm import LSTMEncoder, LSTMCellDecoder
from embeddings import load_glove_embeddings, Embeddings

In [2]:
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    print("Running on the GPU")
else:
    device = torch.device("cpu")
    print("Running on the CPU")

Running on the CPU


In [3]:
dim = 50
data_loc = '../data'
model_loc = 'model'
batch = 1
max_length = 30

In [4]:
embeddings = load_glove_embeddings(dim, data_loc)
len(embeddings)

400003

In [5]:
def load_encoder(filename):
    encoder = LSTMEncoder(dim, dim)
    with open(filename, 'rb') as f:
        encoder.load_state_dict(torch.load(f, map_location=torch.device('cpu')))
    return encoder

def load_decoder(filename):
    decoder = LSTMCellDecoder(dim, dim*2, len(embeddings))
    with open(filename, 'rb') as f:
        decoder.load_state_dict(torch.load(f, map_location=torch.device('cpu')))
    return decoder

In [6]:
encoder = load_encoder(os.path.join(model_loc, 'encoder.pt'))
decoder = load_decoder(os.path.join(model_loc, 'decoder.pt'))


In [7]:
from itertools import chain
from typing import List, Iterable
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, PackedSequence

def sentence_to_list(sent: str) -> Iterable[str]:
    return chain([Embeddings.SOS_STR], sent.split(), [Embeddings.EOS_STR])

def to_packed_sequence(tensors: List[torch.Tensor]) -> PackedSequence:
    lens = [len(t) for t in tensors]
    packed = pack_padded_sequence(pad_sequence(tensors), lens, enforce_sorted=False).to(device)
    return packed

def strings_to_batch(strings: List[str]) -> PackedSequence:
    sents_emb = [embeddings.sentence_to_tensor(sentence_to_list(sent)) for sent in strings]
    return to_packed_sequence(sents_emb)

def strings_to_ids(strings: List[str]) -> PackedSequence:
    ids = [embeddings.sentence_to_ids(sentence_to_list(sent)) for sent in strings]
    return to_packed_sequence(ids)

In [8]:
from dataset import Oxford2019Dataset
from torch.utils.data import DataLoader

def make_data_loader(filename: str, file_loc: str = os.path.join(data_loc, 'Oxford-2019')) -> DataLoader:
    dataset = Oxford2019Dataset(data_loc=os.path.join(file_loc, filename))
    data_loader = DataLoader(dataset, batch_size=batch, shuffle=True)
    return data_loader

test_set = make_data_loader('test.txt')

In [79]:
from torch.nn import LogSoftmax, Softmax
from random import randint

encoder.eval()
decoder.eval()
lsm = LogSoftmax(dim=1)
sm = Softmax(dim=1)


def generate_text(word, example):
    word_first_example = ' '.join((word, example))

    encoder_input = embeddings.sentence_to_tensor(sentence_to_list(word_first_example))
    encoder_input = encoder_input.unsqueeze(dim=1)
    e_out, e_hidden = encoder(encoder_input)

    decoder_input = torch.cat((e_hidden[0], e_hidden[1]), dim=1)
    decoder_input = decoder_input.unsqueeze(dim=0)

    result = [Embeddings.SOS_STR]
    while result[-1] != embeddings.EOS_STR and len(result) <= max_length:
        token_emb = embeddings[result[-1]]
        d_out, decoder_input = decoder(token_emb.view(1, 1, -1), decoder_input)
        sort = torch.argsort(sm(d_out))
        id = sort[0][0][randint(0, 5)]
        id = sort[0][0][0]
        token = embeddings.id2word[id.item()]
        result.append(token)
    return ' '.join(result)

In [77]:
all_results = {}
with torch.no_grad():
    for words, defs, examples in test_set:
        for word, definition, example in zip(words, defs, examples):
            text = generate_text(word, example)
            all_results[word] = text

        if len(all_results) > 5:
            break
all_results

{'vespers': '<s> kissane termly termly eleniak termly kissane woundwort bawean 5,430 termly eleniak bawean 5,430 kissane bawean termly eleniak eleniak 5,430 bawean bawean woundwort woundwort termly eleniak bawean eleniak kissane kissane eleniak',
 'ancestor': '<s> bawean woundwort termly termly 5,430 woundwort 5,430 eleniak eleniak termly kissane kissane termly bawean eleniak 5,430 5,430 bawean 5,430 kissane eleniak termly kissane bawean 5,430 bawean termly eleniak kissane kissane',
 'monitor': '<s> eleniak termly bawean 5,430 5,430 5,430 eleniak termly woundwort eleniak 5,430 eleniak eleniak termly kissane bawean 5,430 kissane termly termly termly 5,430 bawean kissane 5,430 bawean kissane eleniak termly termly',
 'canvasback': '<s> 5,430 termly eleniak eleniak 5,430 bawean 5,430 bawean bawean 5,430 bawean termly 5,430 eleniak 5,430 woundwort kissane 5,430 woundwort termly eleniak woundwort woundwort bawean bawean eleniak bawean kissane bawean woundwort',
 'skimmer': '<s> kissane wound

In [80]:
generate_text("guildhall", "from 1709 until the early nineteenth century the goldsmiths' company had their guildhall in werburgh street , close to dublin castle")

'<s> bawean bawean bawean bawean bawean bawean bawean bawean bawean bawean bawean bawean bawean bawean bawean bawean bawean bawean bawean bawean bawean bawean bawean bawean bawean bawean bawean bawean bawean bawean'