In [1]:
import numpy as np
import torch
import torch.nn.functional as F
import data
from helpers import Config

In [2]:
SEED = 42
# Absolute path (looks host-specific, e.g. a FloydHub mount) — adjust for local runs.
DATA = '/floyd/home/sein/'
CUDA = False
SAMPLE = True  # NOTE(review): not read by any cell visible in this notebook
CONFIG_NAME = 'text_generation_beam_search'

# SEED was declared but never applied; seed every RNG library in use so any
# stochastic step in this notebook is reproducible on Restart & Run All.
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if CUDA else "cpu")
args = Config(CONFIG_NAME)

In [3]:
# Load the corpus from DATA; its `dictionary` is the vocabulary used by the
# cells below to map words to indices (word2idx) and back (idx2word).
corpus = data.Corpus(DATA)
# Vocabulary size. NOTE(review): not referenced by any cell visible here.
ntokens = len(corpus.dictionary)

In [4]:
# Restore the trained model onto `device` and switch it to inference mode
# (the trailing expression displays the architecture).
# SECURITY: torch.load unpickles the file, which can execute arbitrary code —
# only point args.checkpoint at checkpoints you trust.
model = torch.load(args.checkpoint, map_location=device)
model.eval()

WD_LSTM(
  (variational_dropout): VariationalDropout()
  (encoder): Embedding(18818, 400)
  (rnns): ModuleList(
    (0): WeightDrop(
      (module): LSTM(400, 1150)
    )
    (1): WeightDrop(
      (module): LSTM(1150, 1150)
    )
    (2): WeightDrop(
      (module): LSTM(1150, 400)
    )
  )
  (decoder): Linear(in_features=400, out_features=18818, bias=True)
)

In [5]:
def get_next_words(model, word, hidden, beam_size=None):
    """Run one decoding step and return the most probable next words.

    Args:
        model: language model called as ``model(word, hidden)`` that returns a
            4-tuple whose first two items are the output logits and the new
            hidden state.
        word: token-index tensor to feed for this step.
        hidden: recurrent state to continue from.
        beam_size: number of candidates to return; defaults to
            ``args.beam_size`` (the notebook config) when omitted.

    Returns:
        ``(word_scores, word_indices, hidden)``: the ``beam_size`` largest
        softmax probabilities (in descending order), their vocabulary
        indices, and the model's updated hidden state.
    """
    if beam_size is None:
        beam_size = args.beam_size
    # Inference only: without no_grad every step stays attached to an autograd
    # graph that keeps growing as hidden states are chained across beam steps.
    with torch.no_grad():
        output, hidden, _, _ = model(word, hidden)
    # Assumes a single input token and batch size 1, so squeeze() leaves a
    # 1-D vector of vocabulary logits — TODO confirm for other input shapes.
    word_weights = output.squeeze().cpu()
    probabilities = F.softmax(word_weights, dim=-1)
    word_scores, word_indices = probabilities.topk(beam_size)
    return word_scores, word_indices, hidden

In [6]:
def run_candidates(words, hidden, p):
    """Extend one beam hypothesis by a single word.

    Feeds the last word of ``words`` to the global ``model`` from ``hidden``
    and returns one extended hypothesis per top-k next word, skipping
    ``<unk>`` (so fewer than ``args.beam_size`` may be returned).

    Args:
        words: tokens of the hypothesis so far; ``words[-1]`` must be present
            in ``corpus.dictionary.word2idx``.
        hidden: recurrent state after consuming ``words[:-1]``.
        p: accumulated log-probability of ``words``.

    Returns:
        list of ``(words + [next_word], new_hidden, p + log P(next_word))``.
    """
    # Build the step input locally instead of mutating the global `cue`
    # tensor, which is only created in a later cell (hidden-state dependency).
    last_idx = corpus.dictionary.word2idx[words[-1]]
    step_input = torch.tensor([[last_idx]], device=device)
    word_scores, word_indices, new_hidden = get_next_words(model, step_input, hidden)

    candidates = []
    for ws, wi in zip(word_scores, word_indices):
        next_word = corpus.dictionary.idx2word[int(wi)]
        if next_word == '<unk>':
            continue
        # All children of this hypothesis share the same post-step hidden state.
        candidates.append((words + [next_word], new_hidden, p + np.log(ws.item())))
    return candidates

In [9]:
# Prime the model with the cue. Every cue word except the last is fed here, so
# `hidden` summarizes cue_words[:-1]; the beam search feeds the final word
# itself (run_candidates consumes words[-1]).
hidden = model.init_hidden(1)
cue_words = ["george", ":", "will"]
cue = torch.tensor([[corpus.dictionary.word2idx[cue_words[0]]]]).to(device)
with torch.no_grad():  # inference only — don't record an autograd graph
    for word in cue_words[1:]:
        output, hidden, _, _ = model(cue, hidden)
        # Reuse the same 1x1 tensor as the next step's input.
        cue.fill_(corpus.dictionary.word2idx[word])

In [10]:
# Beam search: each step extends every surviving hypothesis by one word and
# keeps the args.beam_size best by summed log-probability.
current_candidates = run_candidates(cue_words, hidden, 0)
for _ in range(args.words):
    new_candidates = []
    for cw, h, s in current_candidates:
        new_candidates.extend(run_candidates(cw, h, s))
    new_candidates.sort(key=lambda elem: elem[2], reverse=True)
    current_candidates = new_candidates[:args.beam_size]
# Pick the best surviving hypothesis. Using current_candidates (not
# new_candidates) also works when args.words == 0, where new_candidates would
# never be bound; max() avoids re-sorting an already sorted list.
winner = max(current_candidates, key=lambda elem: elem[2])[0]
print(" ".join(winner))

george : will you stop it ? <eos> jerry : i don ' t know . <eos> jerry : i don ' t know . <eos> jerry : i don ' t know . <eos> jerry : i don ' t know . <eos> jerry : i don ' t know . <eos> jerry
