Boost </s> (#386)
* implement boost_eos

* make boost-eos optional

* moved boost_eos to length normalizers

* dedicated class for eos boosting

* remove outdated code

* fix unit tests
msperber authored and neubig committed May 25, 2018
1 parent bc8978c commit a7d1338
Showing 2 changed files with 77 additions and 53 deletions.
94 changes: 56 additions & 38 deletions xnmt/length_normalization.py
@@ -6,13 +6,15 @@

from xnmt.persistence import serializable_init, Serializable
from xnmt import search_strategy
from xnmt.vocab import Vocab

class LengthNormalization(object):
'''
A template class to generate translation from the output probability model.
'''
def normalize_completed(self, completed_hyps:Sequence['search_strategy.BeamSearch.Hypothesis'], src_length:Optional[int]=None) \
-> Sequence[float]:
"""
A template class to adjust scores for length normalization during search.
"""

def normalize_completed(self, completed_hyps: Sequence['search_strategy.BeamSearch.Hypothesis'],
src_length: Optional[int] = None) -> Sequence[float]:
"""
Apply normalization step to completed hypotheses after search and return the normalized scores.
@@ -23,72 +25,71 @@ def normalize_completed(self, completed_hyps:Sequence['search_strategy.BeamSearc
normalized scores
"""
raise NotImplementedError('normalize_completed must be implemented in LengthNormalization subclasses')
def normalize_partial(self, score_so_far, score_to_add, new_len):

def normalize_partial_topk(self, score_so_far, score_to_add, new_len):
"""
Apply normalization step after expanding a partial hypothesis and selecting the top k scores.
Args:
score_so_far:
score_to_add:
new_len: length of output hyp with current word already appended
score_so_far: log score of the partial hypothesis
score_to_add: log score of the top-k item that is to be added
new_len: new length of partial hypothesis with current word already appended
Returns:
new score after applying score_to_add to score_so_far
normalization step applied during the search
"""
return score_so_far + score_to_add # default behavior: add up the log probs


class NoNormalization(LengthNormalization, Serializable):
'''
Adding no form of length normalization
'''
"""
Adding no form of length normalization.
"""
yaml_tag = '!NoNormalization'

@serializable_init
def __init__(self):
pass

def normalize_completed(self, completed_hyps:Sequence['search_strategy.BeamSearch.Hypothesis'], src_length:Optional[int]=None) \
-> Sequence[float]:
def normalize_completed(self, completed_hyps: Sequence['search_strategy.BeamSearch.Hypothesis'],
src_length: Optional[int] = None) -> Sequence[float]:
return [hyp.score for hyp in completed_hyps]

class AdditiveNormalization(LengthNormalization, Serializable):
'''
"""
Adding a fixed word penalty everytime the word is added.
'''
"""
yaml_tag = '!AdditiveNormalization'

@serializable_init
def __init__(self, penalty:Real=-0.1, apply_during_search:bool=False):
def __init__(self, penalty: Real = -0.1, apply_during_search: bool = False):
self.penalty = penalty
self.apply_during_search = apply_during_search

def normalize_completed(self, completed_hyps:Sequence['search_strategy.BeamSearch.Hypothesis'], src_length:Optional[int]=None) \
-> Sequence[float]:
def normalize_completed(self, completed_hyps: Sequence['search_strategy.BeamSearch.Hypothesis'],
src_length: Optional[int] = None) -> Sequence[float]:
if self.apply_during_search:
return [hyp.score for hyp in completed_hyps]
else:
return [hyp.score + (len(hyp.id_list) * self.penalty) for hyp in completed_hyps]
def normalize_partial(self, score_so_far, score_to_add, new_len):
def normalize_partial_topk(self, score_so_far, score_to_add, new_len):
return score_so_far + score_to_add + (self.penalty if self.apply_during_search else 0.0)


class PolynomialNormalization(LengthNormalization, Serializable):
'''
"""
Dividing by the length (raised to some power)
'''
"""
yaml_tag = '!PolynomialNormalization'

@serializable_init
def __init__(self, m:Real=1, apply_during_search:bool=False):
def __init__(self, m: Real = 1, apply_during_search: bool = False):
self.m = m
self.apply_during_search = apply_during_search
self.pows = []

def normalize_completed(self, completed_hyps:Sequence['search_strategy.BeamSearch.Hypothesis'], src_length:Optional[int]=None) \
-> Sequence[float]:
def normalize_completed(self, completed_hyps: Sequence['search_strategy.BeamSearch.Hypothesis'],
src_length: Optional[int] = None) -> Sequence[float]:
if self.apply_during_search:
return [hyp.score for hyp in completed_hyps]
else:
return [(hyp.score / pow(len(hyp.output.word_ids), self.m)) for hyp in completed_hyps]
def normalize_partial(self, score_so_far, score_to_add, new_len):
def normalize_partial_topk(self, score_so_far, score_to_add, new_len):
if self.apply_during_search:
self.update_pows(new_len)
return (score_so_far * self.pows[new_len-1] + score_to_add) / self.pows[new_len]
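
A side note on the renamed normalize_partial_topk hook: with apply_during_search=True, the return statement above keeps a running length-normalized score by undoing the previous hypothesis' normalization, adding the new word's log probability, and re-normalizing by the new length. A tiny standalone illustration with made-up numbers (plain Python; it assumes self.pows[i] holds i ** m, which update_pows is expected to maintain):

m = 1.0
score_so_far = -2.4 / 2 ** m   # length-normalized score of a 2-token partial hypothesis
score_to_add = -0.7            # log probability of the word being appended
new_len = 3

prev_sum = score_so_far * (new_len - 1) ** m   # undo the previous normalization: -2.4
new_score = (prev_sum + score_to_add) / new_len ** m
print(new_score)   # (-2.4 - 0.7) / 3 = -1.0333...
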
@@ -101,11 +102,11 @@ def update_pows(self, new_len):


class MultinomialNormalization(LengthNormalization, Serializable):
'''
"""
The algorithm followed by:
Tree-to-Sequence Attentional Neural Machine Translation
https://arxiv.org/pdf/1603.06075.pdf
'''
"""
yaml_tag = '!MultinomialNormalization'

@serializable_init
@@ -119,7 +120,8 @@ def trg_length_prob(self, src_length, trg_length):
return (src_stat.trg_len_distribution.get(trg_length, 0) + 1) / (src_stat.num_sents + v)
return 1

def normalize_completed(self, completed_hyps, src_length=None):
def normalize_completed(self, completed_hyps: Sequence['search_strategy.BeamSearch.Hypothesis'],
src_length: Optional[int] = None) -> Sequence[float]:
"""
Args:
completed_hyps:
@@ -131,12 +133,12 @@ def normalize_completed(self, completed_hyps, src_length=None):


class GaussianNormalization(LengthNormalization, Serializable):
'''
"""
The Gaussian regularization encourages the inference
to select sents that have similar lengths as the
sents in the training set.
refer: https://arxiv.org/pdf/1509.04942.pdf
'''
"""
yaml_tag = '!GaussianNormalization'

@serializable_init
@@ -158,6 +160,22 @@ def fit_distribution(self):
def trg_length_prob(self, trg_length):
return self.distr.pdf(trg_length)

def normalize_completed(self, completed_hyps:Sequence['search_strategy.BeamSearch.Hypothesis'], src_length:Optional[int]=None) \
-> Sequence[float]:
def normalize_completed(self, completed_hyps: Sequence['search_strategy.BeamSearch.Hypothesis'],
src_length: Optional[int] = None) -> Sequence[float]:
return [hyp.score / self.trg_length_prob(len(hyp.id_list)) for hyp in completed_hyps]


class EosBooster(Serializable):
"""
Callable that applies boosting of end-of-sequence token, can be used with :class:`xnmt.search_strategy.BeamSearch`.
Args:
boost_val: value to add to the eos token's log probability. Positive values make sentences shorter, negative values
make sentences longer.
"""
yaml_tag = "!EosBooster"
@serializable_init
def __init__(self, boost_val: float):
self.boost_val = boost_val
def __call__(self, scores:np.ndarray) -> None:
scores[Vocab.ES] += self.boost_val
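
To make the new class concrete, here is a minimal, illustrative sketch that is not part of the commit: it assumes numpy is importable in the calling code, that Vocab.ES is the integer index of the </s> token (as used above), and it uses made-up probabilities.

import numpy as np
from xnmt.length_normalization import EosBooster
from xnmt.vocab import Vocab

# Hypothetical log-softmax scores over a tiny vocabulary.
scores = np.log(np.array([0.25, 0.25, 0.5]))
booster = EosBooster(boost_val=1.5)
booster(scores)
# The </s> entry is now 1.5 higher; the array is mutated in place and is no
# longer a normalized distribution. The shift deliberately favors hypotheses
# that emit </s>, i.e. shorter outputs.
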
36 changes: 21 additions & 15 deletions xnmt/search_strategy.py
@@ -1,11 +1,12 @@
from collections import namedtuple
import math
from typing import Optional, Callable

import dynet as dy
import numpy as np

import xnmt.batcher
from xnmt.length_normalization import NoNormalization
from xnmt.length_normalization import NoNormalization, LengthNormalization
from xnmt.persistence import Serializable, serializable_init, bare
from xnmt.vocab import Vocab

@@ -21,9 +22,9 @@
SearchOutput = namedtuple('SearchOutput', ['word_ids', 'attentions', 'score', 'logsoftmaxes', 'state', 'mask'])

class SearchStrategy(object):
'''
"""
A template class to generate translation from the output probability model. (Non-batched operation)
'''
"""
def generate_output(self, translator, dec_state,
src_length=None, forced_trg_ids=None):
"""
@@ -38,12 +39,12 @@ def generate_output(self, translator, dec_state,
raise NotImplementedError('generate_output must be implemented in SearchStrategy subclasses')

class GreedySearch(Serializable, SearchStrategy):
'''
"""
Performs greedy search (aka beam search with beam size 1)
Args:
max_len (int): maximum number of tokens to generate.
'''
"""

yaml_tag = '!GreedySearch'

@@ -63,7 +64,6 @@ def generate_output(self, translator, initial_state,
# Search Variables
done = None
current_state = initial_state
current_output = None
for length in range(self.max_len):
prev_word = word_ids[length-1] if length > 0 else None
current_output = translator.output_one_step(prev_word, current_state)
@@ -104,21 +104,25 @@ class BeamSearch(Serializable, SearchStrategy):
Performs beam search.
Args:
beam_size (int):
max_len (int): maximum number of tokens to generate.
len_norm (LengthNormalization): type of length normalization to apply
one_best (bool): Whether to output the best hyp only or all completed hyps.
beam_size: number of beams
max_len: maximum number of tokens to generate.
len_norm: type of length normalization to apply
one_best: Whether to output the best hyp only or all completed hyps.
scores_proc: apply an optional operation on all scores prior to choosing the top k.
E.g. use with :class:`xnmt.length_normalization.EosBooster`.
"""

yaml_tag = '!BeamSearch'
Hypothesis = namedtuple('Hypothesis', ['score', 'output', 'parent', 'word'])

@serializable_init
def __init__(self, beam_size=1, max_len=100, len_norm=bare(NoNormalization), one_best=True):
def __init__(self, beam_size: int = 1, max_len: int = 100, len_norm: LengthNormalization = bare(NoNormalization),
one_best: bool = True, scores_proc: Optional[Callable[[np.ndarray], None]] = None):
self.beam_size = beam_size
self.max_len = max_len
self.len_norm = len_norm
self.one_best = one_best
self.scores_proc = scores_proc

def generate_output(self, translator, initial_state, src_length=None, forced_trg_ids=None):
# TODO(philip30): can only do single decoding, not batched
@@ -142,14 +146,16 @@ def generate_output(self, translator, initial_state, src_length=None, forced_trg
continue
current_output = translator.output_one_step(prev_word, prev_state)
score = current_output.logsoftmax.npvalue().transpose()
if self.scores_proc:
self.scores_proc(score)
# Next Words
if forced_trg_ids is None:
top_words = np.argpartition(score, max(-len(score),-self.beam_size))[-self.beam_size:]
else:
top_words = [forced_trg_ids[length]]
# Queue next states
for cur_word in top_words:
new_score = self.len_norm.normalize_partial(hyp.score, score[cur_word], length+1)
new_score = self.len_norm.normalize_partial_topk(hyp.score, score[cur_word], length + 1)
new_set.append(self.Hypothesis(new_score, current_output, hyp, cur_word))
# Next top hypothesis
active_hyp = sorted(new_set, key=lambda x: x.score, reverse=True)[:self.beam_size]
@@ -259,7 +265,7 @@ def sample_one(self, translator, initial_state, forced_trg_ids=None):
return SearchOutput(samples, attentions, scores, logsofts, states, masks)


class MctsNode:
class MctsNode(object):
def __init__(self, parent, prior_dist, word, attention, translator, dec_state):
self.parent = parent
self.prior_dist = prior_dist # log of softmax
@@ -377,9 +383,9 @@ def greedy_choice(logsoftmax):


class MctsSearch(Serializable, SearchStrategy):
'''
"""
Performs search with Monte Carlo Tree Search
'''
"""
yaml_tag = '!MctsSearch'

@serializable_init
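
Taken together, a hedged sketch of how the new pieces could be wired up, built only from the constructor signatures visible in this diff (the boost and beam values are illustrative, and how the resulting search strategy gets attached to a translator or inference config is not shown here):

from xnmt.length_normalization import EosBooster, PolynomialNormalization
from xnmt.search_strategy import BeamSearch

# Beam search that length-normalizes during search and boosts </s> at every
# expansion step, nudging the decoder toward shorter outputs.
search_strategy = BeamSearch(
  beam_size=5,
  max_len=100,
  len_norm=PolynomialNormalization(m=1, apply_during_search=True),
  scores_proc=EosBooster(boost_val=1.3),
)
# Inside generate_output, scores_proc is called on each partial hypothesis'
# log-softmax scores right before the top-k words are selected (the
# `if self.scores_proc: self.scores_proc(score)` hunk above).
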
