Skip to content

Commit

Permalink
Fixbug Beam Search
Browse files Browse the repository at this point in the history
  • Loading branch information
lvapeab committed Jan 11, 2017
1 parent b7ca708 commit aed24df
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 103 deletions.
7 changes: 4 additions & 3 deletions keras_wrapper/beam_search_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(self, models, dataset, params_prediction, verbose=0):
params_prediction.get('optimized_search') is not None else False
self.verbose = verbose
if self.verbose > 0:
print "Using optimized_search=", self.optimized_search
print "Using optimized_search =", self.optimized_search

# PREDICTION FUNCTIONS: Functions for making prediction on input samples

Expand Down Expand Up @@ -126,7 +126,6 @@ def beam_search(self, X, params, null_sym=2):
prev_outs=prev_outs)
else:
probs = self.predict_cond(self.models, X, state_below, params, ii)

# total score for every sample is sum of -log of word prb
cand_scores = np.array(hyp_scores)[:, None] - np.log(probs)
cand_flat = cand_scores.flatten()
Expand All @@ -141,11 +140,13 @@ def beam_search(self, X, params, null_sym=2):

# Form a beam for the next iteration
new_hyp_samples = []
new_trans_indices = []
new_hyp_scores = np.zeros(k-dead_k).astype('float32')
if params['pos_unk']:
new_hyp_alphas = []
for idx, [ti, wi] in enumerate(zip(trans_indices, word_indices)):
new_hyp_samples.append(hyp_samples[ti]+[wi])
new_trans_indices.append(ti)
new_hyp_scores[idx] = copy.copy(costs[idx])
if params['pos_unk']:
new_hyp_alphas.append(hyp_alphas[ti]+[alphas[ti]])
Expand All @@ -164,7 +165,7 @@ def beam_search(self, X, params, null_sym=2):
sample_alphas.append(new_hyp_alphas[idx])
dead_k += 1
else:
indices_alive.append(idx)
indices_alive.append(new_trans_indices[idx])
new_live_k += 1
hyp_samples.append(new_hyp_samples[idx])
hyp_scores.append(new_hyp_scores[idx])
Expand Down

0 comments on commit aed24df

Please sign in to comment.