In [1]:
import numpy as np
import torch
import torch.nn.functional as F
import data
from helpers import Config

In [2]:
# --- Notebook configuration -------------------------------------------------
SEED = 42  # applied to the NumPy and PyTorch RNGs just below
DATA = '/floyd/input/ptb/'  # directory handed to data.Corpus
CUDA = False  # set to True to run on the "cuda" device
CONFIG_NAME = 'text_generation_beam_search'  # name handed to helpers.Config

# Apply SEED to both RNG libraries imported by this notebook so that any
# stochastic operation gives the same result after a kernel restart.
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if CUDA else "cpu")
# Later cells read args.checkpoint and args.beam_size from this object.
args = Config(CONFIG_NAME)

In [3]:
# Build the corpus from the DATA directory using the project-local `data` module.
corpus = data.Corpus(DATA)
# Number of entries in the corpus dictionary (presumably the vocabulary size
# the model was trained with -- TODO confirm against the checkpoint).
ntokens = len(corpus.dictionary)

In [4]:
# Restore the trained model from the checkpoint path named in the config,
# remapping its tensors onto `device`.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only point args.checkpoint at checkpoints you trust.
with open(args.checkpoint, 'rb') as f:
    model = torch.load(f, map_location=device)
# Switch the module to evaluation mode. This is the cell's last expression,
# so the notebook also displays the module that eval() returns.
model.eval()

WD_LSTM(
  (drop): Dropout(p=0.2)
  (encoder): Embedding(10000, 400)
  (rnns): ModuleList(
    (0): WeightDrop(
      (module): LSTM(400, 800)
    )
    (1): WeightDrop(
      (module): LSTM(800, 800)
    )
    (2): WeightDrop(
      (module): LSTM(800, 400)
    )
  )
  (decoder): Linear(in_features=800, out_features=10000, bias=True)
)

In [5]:
def get_next_words(model, word, hidden, beam_size=None):
    """Run one model step and return the most probable next words.

    Args:
        model: callable invoked as ``model(word, hidden)`` that returns an
            ``(output, hidden)`` pair.
        word: model input, passed through unchanged.
        hidden: recurrent state, passed through unchanged.
        beam_size: number of top-scoring words to return. When ``None``
            (the default) ``args.beam_size`` from the notebook config is used.

    Returns:
        A ``(word_scores, word_indices, hidden)`` tuple: the ``beam_size``
        largest softmax probabilities (highest first), their positions in the
        output vector, and the hidden state returned by the model.
    """
    if beam_size is None:
        beam_size = args.beam_size
    # Inference only: without no_grad every step would record an autograd
    # graph that stays reachable through the returned hidden state.
    with torch.no_grad():
        output, hidden = model(word, hidden)
        # assumes the output holds a single score vector (one step, batch of
        # one) so that squeeze() leaves a 1-D tensor -- TODO confirm
        word_weights = output.squeeze().cpu()
        probabilities = F.softmax(word_weights, dim=0)
        word_scores, word_indices = probabilities.topk(beam_size)
    return word_scores, word_indices, hidden

In [6]:
def run_candidates(words, hidden, p):
    """Expand one beam hypothesis by a single word.

    Feeds the last word of ``words`` to the notebook-level ``model`` and
    returns one new hypothesis per top-scoring next word, skipping ``'<unk>'``.

    Args:
        words: list of word strings generated so far. The last one is the
            next model input and must be a key of ``corpus.dictionary.word2idx``.
        hidden: model state to continue from when feeding ``words[-1]``.
        p: accumulated log-probability of ``words``.

    Returns:
        A list of ``(words + [next_word], new_hidden, p + log(prob))`` tuples.
        It can be shorter than the beam size because ``'<unk>'`` is dropped.
    """
    dictionary = corpus.dictionary
    # Build a fresh (1, 1) input here rather than overwriting the notebook-level
    # `cue` tensor in place, so the function does not rely on a global that is
    # only created in a later cell.
    last_word = torch.tensor([[dictionary.word2idx[words[-1]]]]).to(device)
    word_scores, word_indices, hidden = get_next_words(model, last_word, hidden)
    candidates = []
    # tolist() yields plain Python floats/ints, so idx2word is indexed with an
    # int and each word is looked up only once.
    for score, index in zip(word_scores.tolist(), word_indices.tolist()):
        next_word = dictionary.idx2word[index]
        if next_word != '<unk>':
            candidates.append((words + [next_word], hidden, p + np.log(score)))
    return candidates

In [11]:
# Prime the model on the cue phrase: feed every cue word except the last, so
# that `hidden` reflects the prefix and `cue` ends up holding the final word.
# The final word itself is fed by run_candidates (it feeds words[-1]).
hidden = model.init_hidden(1)  # the 1 is presumably the batch size -- TODO confirm
cue_words = ["the", "federal", "reserve", "board"]
cue = torch.tensor([[corpus.dictionary.word2idx[cue_words[0]]]]).to(device)
# Inference only: no_grad keeps these forward passes from recording an
# autograd graph that would stay attached to `hidden`.
with torch.no_grad():
    for word in cue_words[1:]:
        output, hidden = model(cue, hidden)
        # Reuse the same (1, 1) tensor for the next input word.
        cue.fill_(corpus.dictionary.word2idx[word])

In [14]:
# Beam search: repeatedly extend every surviving hypothesis by one word and
# keep only the highest-scoring ones.
MAX_NEW_WORDS = 50  # expansion rounds after the initial one
BEAM_WIDTH = 3      # hypotheses kept after each round

current_candidates = run_candidates(cue_words, hidden, 0)
for _ in range(MAX_NEW_WORDS):
    new_candidates = []
    # Each candidate is (words, hidden state, cumulative log-probability).
    for cw, h, s in current_candidates:
        new_candidates.extend(run_candidates(cw, h, s))
    if not new_candidates:
        # Every expansion was filtered out; keep the last non-empty beam.
        break
    new_candidates.sort(key=lambda elem: elem[2], reverse=True)
    current_candidates = new_candidates[:BEAM_WIDTH]

# Beams produced inside the loop are sorted, but the initial one is not, so
# select by score. Fall back to the cue itself if nothing could be generated.
if current_candidates:
    winner = max(current_candidates, key=lambda elem: elem[2])[0]
else:
    winner = cue_words
print(str.join(" ", winner))

the federal reserve board said it is n't clear that the government would be able to reduce the value of its assets <eos> the new york stock exchange composite trading was quoted at $ N a share down N cents <eos> the company said it expects to post a loss of $ N million or
