In [5]:
import os
import torch
from wdm import LSTMEncoder, LSTMCellDecoder
from embeddings import load_glove_embeddings, Embeddings

In [6]:
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    print("Running on the GPU")
else:
    device = torch.device("cpu")
    print("Running on the CPU")

Running on the CPU


In [17]:
dim = 50
data_loc = '../data'
model_loc = 'model'
batch = 1
max_length = 30

In [8]:
embeddings = load_glove_embeddings(dim, data_loc)
len(embeddings)

400003

In [12]:
def load_encoder(filename):
    encoder = LSTMEncoder(dim, dim)
    with open(filename, 'rb') as f:
        encoder.load_state_dict(torch.load(f, map_location=torch.device('cpu')))
    return encoder

def load_decoder(filename):
    decoder = LSTMCellDecoder(dim, dim*2, len(embeddings))
    with open(filename, 'rb') as f:
        decoder.load_state_dict(torch.load(f, map_location=torch.device('cpu')))
    return decoder

In [13]:
encoder = load_encoder(os.path.join(model_loc, 'encoder.pt'))
decoder = load_decoder(os.path.join(model_loc, 'decoder.pt'))


In [4]:
from itertools import chain
from typing import List, Iterable, Callable
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, PackedSequence

def sentence_to_list(sent: str) -> Iterable[str]:
    return chain([Embeddings.SOS_STR], sent.split(), [Embeddings.EOS_STR])

def to_packed_sequence(tensors: List[torch.Tensor]) -> PackedSequence:
    lens = [len(t) for t in tensors]
    packed = pack_padded_sequence(pad_sequence(tensors), lens, enforce_sorted=False).to(device)
    return packed

def strings_to_batch(strings: List[str]) -> PackedSequence:
    sents_emb = [embeddings.sentence_to_tensor(sentence_to_list(sent)) for sent in strings]
    return to_packed_sequence(sents_emb)

def strings_to_ids(strings: List[str]) -> PackedSequence:
    ids = [embeddings.sentence_to_ids(sentence_to_list(sent)) for sent in strings]
    return to_packed_sequence(ids)

In [19]:
from dataset import Oxford2019Dataset
from torch.utils.data import DataLoader

def make_data_loader(filename: str, file_loc: str = os.path.join(data_loc, 'Oxford-2019')) -> DataLoader:
    dataset = Oxford2019Dataset(data_loc=os.path.join(file_loc, filename))
    data_loader = DataLoader(dataset, batch_size=batch, shuffle=True)
    return data_loader

test_set = make_data_loader('test.txt')

In [21]:
from torch.nn import LogSoftmax

encoder.eval()
decoder.eval()
lsm = LogSoftmax(dim=1)

all_results = []

for words, defs, examples in test_set:
    # print(words[0], defs[0], examples[0])
    word_first_examples = (' '.join((w, e)) for w, e in zip(words, examples))

    batch_results = []
    for word_first_example, definition in zip(word_first_examples, defs):
        e_out, e_hidden = encoder(word_first_example)

        decoder_input = torch.cat((e_hidden[0], e_hidden[1]), dim=1)
        decoder_input = decoder_input.unsqueeze(dim=0)

        result = [embeddings.SOS_STR]
        while result[-1] != embeddings.EOS_STR and len(result) <= max_length:
            token_emb = embeddings[result[-1]]
            d_out = decoder(token_emb, decoder_input)
            d_out_lsm = lsm(d_out)
            id = torch.argmax(d_out_lsm)
            token = embeddings.id2word[id]
            result.append(token)
        batch_results.append(' '.join(result))
    all_results.extend(list(zip(words, batch_results)))

    break

AttributeError: 'str' object has no attribute 'size'