In [1]:
import os
import torch
from wdm import LSTMEncoder, LSTMCellDecoder
from embeddings import load_glove_embeddings, Embeddings

In [None]:
# Pick the first CUDA device when PyTorch can see one, otherwise fall back to the CPU,
# and report the choice so the reader knows where the models will run.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Running on the GPU" if device.type == "cuda" else "Running on the CPU")

In [2]:
# dim       - width passed to load_glove_embeddings and used to size the encoder/decoder.
# data_loc  - directory handed to load_glove_embeddings.
# model_loc - directory holding the saved 'encoder.pt' / 'decoder.pt' checkpoints.
dim, data_loc, model_loc = 50, '../data', '.'

In [None]:
# Load the embeddings for width `dim` from `data_loc`
# (presumably the dim-dimensional GloVe file -- confirm in embeddings.load_glove_embeddings).
embeddings = load_glove_embeddings(
    dim,
    data_loc,
)
# Show the number of entries; the same value sizes the decoder in load_decoder.
len(embeddings)

In [3]:
def _load_weights(module, filename):
    """Restore a ``torch.save``-d state dict from ``filename`` into ``module`` and return it.

    The checkpoint is deserialized with ``map_location='cpu'``. The modules built by the
    loaders below are constructed without a device argument, so (under PyTorch's default)
    they live on the CPU. Mapping to CPU therefore avoids a transient GPU copy, and it lets
    checkpoints saved on a GPU machine load on a CPU-only one. Move the returned module
    with ``.to(device)`` afterwards if needed.

    SECURITY: ``torch.load`` unpickles the file, so only load trusted checkpoints.
    On torch >= 1.13, consider passing ``weights_only=True``.
    """
    with open(filename, 'rb') as f:
        state = torch.load(f, map_location='cpu')
    module.load_state_dict(state)
    return module


def load_encoder(filename):
    """Build ``LSTMEncoder(dim, dim)`` and restore its weights from ``filename``.

    Reads the notebook-level ``dim``. The returned module is on the CPU.
    """
    return _load_weights(LSTMEncoder(dim, dim), filename)


def load_decoder(filename):
    """Build ``LSTMCellDecoder(dim, dim*2, len(embeddings))`` and restore its weights.

    Reads the notebook-level ``dim`` and ``embeddings``. ``len(embeddings)`` is presumably
    the output vocabulary size, and ``dim*2`` presumably matches a bidirectional encoder's
    concatenated state. TODO: confirm both against ``wdm.LSTMCellDecoder``.
    The returned module is on the CPU.
    """
    return _load_weights(LSTMCellDecoder(dim, dim * 2, len(embeddings)), filename)

In [None]:
# Restore the trained weights, then move both modules onto `device`.
# to_packed_sequence places every input batch on `device`, so the parameters must
# live there too; otherwise the forward pass fails with a device mismatch on the GPU.
# (Assumes both classes are torch.nn.Module subclasses -- they expose load_state_dict --
# whose .to() returns the module itself.)
# NOTE(review): if these models are only used for inference, consider calling .eval()
# to disable any dropout -- confirm whether the wdm modules use it.
encoder = load_encoder(os.path.join(model_loc, 'encoder.pt')).to(device)
decoder = load_decoder(os.path.join(model_loc, 'decoder.pt')).to(device)


In [4]:
from itertools import chain
from typing import List, Iterable, Callable
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, PackedSequence

def sentence_to_list(sent: str) -> List[str]:
    """Split ``sent`` on whitespace and wrap the tokens in sentence delimiters.

    Args:
        sent: A raw sentence; tokens are separated by any run of whitespace.

    Returns:
        ``[Embeddings.SOS_STR, *tokens, Embeddings.EOS_STR]`` as a real list. The
        original returned a one-shot ``itertools.chain`` iterator, which silently
        yields nothing on a second pass and does not support ``len()``.
    """
    return [Embeddings.SOS_STR, *sent.split(), Embeddings.EOS_STR]

def to_packed_sequence(tensors: List[torch.Tensor]) -> PackedSequence:
    """Pack a list of variable-length sequences into a single PackedSequence on ``device``.

    Each tensor's length is its first-dimension size. The sequences may arrive in any
    length order, because ``enforce_sorted=False`` makes PyTorch sort them internally
    and record the permutation. ``device`` is the notebook-level torch device.
    """
    # pack_sequence == pad_sequence (time-major) followed by pack_padded_sequence
    # with each tensor's first-dimension size as its length.
    return torch.nn.utils.rnn.pack_sequence(tensors, enforce_sorted=False).to(device)

def strings_to_batch(strings: List[str]) -> PackedSequence:
    """Embed each sentence (SOS/EOS-delimited) and pack the results into one batch.

    Each sentence goes through ``sentence_to_list`` and then
    ``embeddings.sentence_to_tensor``. The per-sentence tensors are combined with
    ``to_packed_sequence``.
    """
    embedded = []
    for sentence in strings:
        tokens = sentence_to_list(sentence)
        embedded.append(embeddings.sentence_to_tensor(tokens))
    return to_packed_sequence(embedded)

def strings_to_ids(strings: List[str]) -> PackedSequence:
    """Map each sentence (SOS/EOS-delimited) to token ids and pack them into one batch.

    Each sentence goes through ``sentence_to_list`` and then
    ``embeddings.sentence_to_ids``, which is assumed to return a tensor (it is packed
    by ``to_packed_sequence``).
    """
    def _ids_for(sentence: str):
        return embeddings.sentence_to_ids(sentence_to_list(sentence))

    return to_packed_sequence(list(map(_ids_for, strings)))

