Beam search (#38)
* clean code

Signed-off-by: Norman Heckscher <norman.heckscher@gmail.com>

* Apply TensorBoard logging code

Signed-off-by: Norman Heckscher <norman.heckscher@gmail.com>

* Initial commit of beam search

Beam search is a heuristic search algorithm that explores a graph by expanding the most promising node in a limited set; it is an optimization of best-first search that reduces its memory requirements. (A toy sketch of the idea follows this commit message.)

Signed-off-by: Norman Heckscher <norman.heckscher@gmail.com>

* fix code

Signed-off-by: Norman Heckscher <norman.heckscher@gmail.com>

* fix

Signed-off-by: Norman Heckscher <norman.heckscher@gmail.com>
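
The pruning idea described above, as a minimal self-contained sketch (not part of this commit): at each step every live hypothesis is extended by every vocabulary label, and only the k highest-scoring extensions survive. A real decoder would condition each step's distribution on the prefix; this toy version scores against a fixed per-step table.

import numpy as np

def toy_beam_search(step_probs, k=2):
    # beams hold (label sequence, total log-probability) pairs
    beams = [([], 0.0)]
    for p in step_probs:
        # extend every live beam by every possible label
        candidates = [(seq + [w], score + np.log(p[w]))
                      for seq, score in beams
                      for w in range(len(p))]
        # prune: keep only the k most promising candidates
        beams = sorted(candidates, key=lambda b: b[1], reverse=True)[:k]
    return beams

step_probs = [[0.6, 0.3, 0.1], [0.2, 0.5, 0.3], [0.1, 0.1, 0.8]]
print(toy_beam_search(step_probs, k=2))  # the two best sequences with their scores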
normanheckscher authored and hunkim committed Jan 7, 2017
1 parent 1f1c1c0 commit 1422a88
Showing 3 changed files with 90 additions and 10 deletions.
51 changes: 51 additions & 0 deletions beam.py
@@ -0,0 +1,51 @@
+import tensorflow as tf
+import numpy as np
+
+
+class BeamSearch():
+    def __init__(self, probs):
+        self.probs = probs
+
+    def beamsearch(self, oov, empty, eos, k=1, maxsample=4000, use_unk=False):
+        """Return k samples (beams) and their NLL scores; each sample is a sequence of labels.
+        All samples start with an `empty` label and end with `eos`, or are truncated to a
+        length of `maxsample`. The per-label probabilities are supplied as `probs` at
+        construction time. `use_unk` allows the `oov` (out-of-vocabulary) label in samples.
+        """
+
+        dead_k = 0  # samples that reached eos
+        dead_samples = []
+        dead_scores = []
+        live_k = 1  # samples that did not yet reach eos
+        live_samples = [[empty]]
+        live_scores = [0]
+
+        while live_k and dead_k < k:
+
+            # total score for every sample is the sum of -log of word probabilities
+            cand_scores = np.array(live_scores)[:, None] - np.log(self.probs)
+            if not use_unk and oov is not None:
+                cand_scores[:, oov] = 1e20
+            cand_flat = cand_scores.flatten()
+
+            # find the best (lowest) scores we have from all possible samples and new words
+            ranks_flat = cand_flat.argsort()[:(k - dead_k)]
+            live_scores = cand_flat[ranks_flat]
+
+            # append the new words to their appropriate live sample
+            voc_size = self.probs.shape[1]
+            live_samples = [live_samples[r // voc_size] + [r % voc_size] for r in ranks_flat]
+
+            # live samples that have hit eos or the length cap are finished ("zombies")
+            zombie = [s[-1] == eos or len(s) >= maxsample for s in live_samples]
+
+            # add zombies to the dead
+            dead_samples += [s for s, z in zip(live_samples, zombie) if z]
+            dead_scores += [s for s, z in zip(live_scores, zombie) if z]
+            dead_k = len(dead_samples)
+            # remove zombies from the living
+            live_samples = [s for s, z in zip(live_samples, zombie) if not z]
+            live_scores = [s for s, z in zip(live_scores, zombie) if not z]
+            live_k = len(live_samples)
+
+        return dead_samples + live_samples, dead_scores + live_scores
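
A usage sketch for the class above (not part of the commit; the label ids are arbitrary toy choices): the constructor takes a probability matrix with one row per distribution over the vocabulary, and beamsearch returns label sequences together with their negative-log-likelihood scores, lower being better.

import numpy as np
from beam import BeamSearch

probs = np.array([[0.1, 0.6, 0.2, 0.1]])  # toy distribution over a 4-label vocabulary

bs = BeamSearch(probs)
# oov=None: no out-of-vocabulary label to mask out
# empty=0:  label every beam starts from
# eos=3:    label that terminates a beam
samples, scores = bs.beamsearch(None, 0, 3, k=2, maxsample=5, use_unk=False)
for seq, nll in zip(samples, scores):
    print(seq, nll)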
45 changes: 36 additions & 9 deletions model.py
@@ -4,6 +4,8 @@
 import random
 import numpy as np
 
+from beam import BeamSearch
+
 class Model():
     def __init__(self, args, infer=False):
         self.args = args
@@ -33,9 +35,23 @@ def __init__(self, args, infer=False):
         self.batch_time = tf.Variable(0.0, name="batch_time", trainable=False)
         tf.summary.scalar("time_batch", self.batch_time)
 
+        def variable_summaries(var):
+            """Attach a lot of summaries to a Tensor (for TensorBoard visualization)."""
+            with tf.name_scope('summaries'):
+                mean = tf.reduce_mean(var)
+                tf.summary.scalar('mean', mean)
+                #with tf.name_scope('stddev'):
+                #    stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
+                #tf.summary.scalar('stddev', stddev)
+                tf.summary.scalar('max', tf.reduce_max(var))
+                tf.summary.scalar('min', tf.reduce_min(var))
+                #tf.summary.histogram('histogram', var)
+
         with tf.variable_scope('rnnlm'):
             softmax_w = tf.get_variable("softmax_w", [args.rnn_size, args.vocab_size])
+            variable_summaries(softmax_w)
             softmax_b = tf.get_variable("softmax_b", [args.vocab_size])
+            variable_summaries(softmax_b)
             with tf.device("/cpu:0"):
                 embedding = tf.get_variable("embedding", [args.vocab_size, args.rnn_size])
                 inputs = tf.split(1, args.seq_length, tf.nn.embedding_lookup(embedding, self.input_data))
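
The summaries attached by variable_summaries (and the time_batch scalar above) only reach TensorBoard after they are merged and written out during training. A minimal sketch, not part of this commit; merged, writer, args.log_dir, num_steps and feed are assumed names, and the tf.summary API is the 0.12-era one used above:

merged = tf.summary.merge_all()                           # collect every summary op registered above
writer = tf.summary.FileWriter(args.log_dir, sess.graph)  # args.log_dir is a hypothetical flag

for step in range(num_steps):                             # num_steps and feed come from the training loop
    summary, _ = sess.run([merged, model.train_op], feed)
    writer.add_summary(summary, step)                     # makes the scalars visible in TensorBoard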
@@ -64,7 +80,7 @@ def loop(prev, _):
         optimizer = tf.train.AdamOptimizer(self.lr)
         self.train_op = optimizer.apply_gradients(zip(grads, tvars))
 
-    def sample(self, sess, words, vocab, num=200, prime='first all', sampling_type=1):
+    def sample(self, sess, words, vocab, num=200, prime='first all', sampling_type=1, pick=0):
         state = sess.run(self.cell.zero_state(1, tf.float32))
         if not len(prime) or prime == " ":
             prime = random.choice(list(vocab.keys()))
@@ -81,6 +97,14 @@ def weighted_pick(weights):
             s = np.sum(weights)
             return(int(np.searchsorted(t, np.random.rand(1)*s)))
 
+        def beam_search_pick(weights):
+            probs[0] = weights
+            samples, scores = BeamSearch(probs).beamsearch(None, vocab.get(prime), None, 2, len(weights), False)
+            sampleweights = samples[np.argmax(scores)]
+            t = np.cumsum(sampleweights)
+            s = np.sum(sampleweights)
+            return(int(np.searchsorted(t, np.random.rand(1)*s)))
+
         ret = prime
         word = prime.split()[-1]
         for n in range(num):
@@ -90,15 +114,18 @@ def weighted_pick(weights):
             [probs, state] = sess.run([self.probs, self.final_state], feed)
             p = probs[0]
 
-            if sampling_type == 0:
-                sample = np.argmax(p)
-            elif sampling_type == 2:
-                if word == '\n':
-                    sample = weighted_pick(p)
-                else:
+            if pick == 1:
+                if sampling_type == 0:
                     sample = np.argmax(p)
-            else: # sampling_type == 1 default:
-                sample = weighted_pick(p)
+                elif sampling_type == 2:
+                    if word == '\n':
+                        sample = weighted_pick(p)
+                    else:
+                        sample = np.argmax(p)
+                else: # sampling_type == 1 default:
+                    sample = weighted_pick(p)
+            elif pick == 2:
+                sample = beam_search_pick(p)
 
             pred = words[sample]
             ret += ' ' + pred
4 changes: 3 additions & 1 deletion sample.py
@@ -18,6 +18,8 @@ def main():
                         help='number of words to sample')
     parser.add_argument('--prime', type=str, default=' ',
                         help='prime text')
+    parser.add_argument('--pick', type=int, default=1,
+                        help='1 = weighted pick, 2 = beam search pick')
     parser.add_argument('--sample', type=int, default=1,
                         help='0 to use max at each timestep, 1 to sample at each timestep, 2 to sample on spaces')
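
With the new flag, beam-search sampling is selected from the command line, e.g. `python sample.py --save_dir save --prime "the" --pick 2` (assuming the script's usual `--save_dir` checkpoint flag). Since `--pick` defaults to 1, existing invocations keep the previous weighted-pick behavior.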

@@ -36,7 +38,7 @@ def sample(args):
         ckpt = tf.train.get_checkpoint_state(args.save_dir)
         if ckpt and ckpt.model_checkpoint_path:
             saver.restore(sess, ckpt.model_checkpoint_path)
-            print(model.sample(sess, words, vocab, args.n, args.prime, args.sample))
+            print(model.sample(sess, words, vocab, args.n, args.prime, args.sample, args.pick))
 
 if __name__ == '__main__':
     main()
