diff --git a/xnmt/__init__.py b/xnmt/__init__.py
index 9a8ac9c2d..36373a0d3 100644
--- a/xnmt/__init__.py
+++ b/xnmt/__init__.py
@@ -58,6 +58,7 @@ import xnmt.persistence
 import xnmt.rl
 import xnmt.compound_expr
+import xnmt.sentence_stats
 
 resolved_serialize_params = {}
 
diff --git a/xnmt/length_normalization.py b/xnmt/length_normalization.py
index c84d45813..99390e4c7 100644
--- a/xnmt/length_normalization.py
+++ b/xnmt/length_normalization.py
@@ -142,28 +142,77 @@ class GaussianNormalization(LengthNormalization, Serializable):
   yaml_tag = '!GaussianNormalization'
 
   @serializable_init
-  def __init__(self, sent_stats):
-    self.stats = sent_stats.trg_stat
+  def __init__(self, sent_stats, length_ratio=False, src_cond=False):
+    """
+    Args:
+      sent_stats: A SentenceStats object
+      length_ratio: Instead of fitting len(trg) in the training set, fit the length ratio len(trg) / len(src)
+      src_cond: Instead of fitting len(trg) in the training set, fit the length of trg with the given len(src)
+    """
+    self.sent_stats = sent_stats
     self.num_sent = sent_stats.num_pair
+    self.length_ratio = length_ratio
+    self.src_cond = src_cond
     self.fit_distribution()
 
   def fit_distribution(self):
-    y = np.zeros(self.num_sent)
-    curr_iter = 0
-    for key in self.stats:
-      iter_end = self.stats[key].num_sents + curr_iter
-      y[curr_iter:iter_end] = key
-      curr_iter = iter_end
-    mu, std = norm.fit(y)
-    self.distr = norm(mu, std)
-
-  def trg_length_prob(self, trg_length):
-    return self.distr.pdf(trg_length)
+    if self.length_ratio:
+      stats = self.sent_stats.src_stat
+      num_sent = self.sent_stats.num_pair
+      y = np.zeros(num_sent)
+      iter = 0
+      for key in stats:
+        for t_len, count in stats[key].trg_len_distribution.items():
+          iter_end = count + iter
+          y[iter:iter_end] = t_len / float(key)
+          iter = iter_end
+      mu, std = norm.fit(y)
+      self.distr = norm(mu, std)
+    elif self.src_cond:
+      stats = self.sent_stats.src_stat
+      self.distr = {}
+      self.max_key = -1
+      for key in stats:
+        if key > self.max_key: self.max_key = key
+        num_trg = stats[key].num_sents
+        y = np.zeros(num_trg)
+        iter = 0
+        for t_len, count in stats[key].trg_len_distribution.items():
+          iter_end = count + iter
+          y[iter:iter_end] = t_len
+          iter = iter_end
+        mu, std = norm.fit(y)
+        if std == 0: std = np.sqrt(key)
+        self.distr[key] = norm(mu, std)
+      for i in range(self.max_key-1, -1, -1):
+        if i not in self.distr:
+          self.distr[i] = self.distr[i+1]
+    else:
+      stats = self.sent_stats.trg_stat
+      y = np.zeros(self.num_sent)
+      curr_iter = 0
+      for key in stats:
+        iter_end = stats[key].num_sents + curr_iter
+        y[curr_iter:iter_end] = key
+        curr_iter = iter_end
+      mu, std = norm.fit(y)
+      self.distr = norm(mu, std)
+
+  def trg_length_prob(self, src_length, trg_length):
+    if self.length_ratio:
+      assert (src_length is not None), "Length of Source Sentence is required in GaussianNormalization when length_ratio=True"
+      return self.distr.pdf(trg_length/src_length)
+    elif self.src_cond:
+      if src_length in self.distr:
+        return self.distr[src_length].pdf(trg_length)
+      else:
+        return self.distr[self.max_key].pdf(trg_length)
+    else:
+      return self.distr.pdf(trg_length)
 
   def normalize_completed(self, completed_hyps: Sequence['search_strategy.BeamSearch.Hypothesis'],
                           src_length: Optional[int] = None) -> Sequence[float]:
-    return [hyp.score / self.trg_length_prob(len(hyp.id_list)) for hyp in completed_hyps]
-
+    return [hyp.score / self.trg_length_prob(src_len, hyp.length) for hyp, src_len in zip(completed_hyps, src_length)]
 
 class EosBooster(Serializable):
   """
diff --git a/xnmt/search_strategy.py b/xnmt/search_strategy.py
index 309b372b2..fb380f93d 100644
--- a/xnmt/search_strategy.py
+++ b/xnmt/search_strategy.py
@@ -114,7 +114,7 @@ class BeamSearch(Serializable, SearchStrategy):
   """
   yaml_tag = '!BeamSearch'
 
-  Hypothesis = namedtuple('Hypothesis', ['score', 'output', 'parent', 'word'])
+  Hypothesis = namedtuple('Hypothesis', ['score', 'output', 'parent', 'word', 'length'])
 
   @serializable_init
   def __init__(self, beam_size: int = 1, max_len: int = 100, len_norm: LengthNormalization = bare(NoNormalization),
@@ -132,7 +132,7 @@ def generate_output(self, translator, initial_state, src_length=None, forced_trg
       logger.warning("Forced decoding with a target longer than max_len. "
                      "Increase max_len to avoid unexpected behavior.")
 
-    active_hyp = [self.Hypothesis(0, None, None, None)]
+    active_hyp = [self.Hypothesis(0, None, None, None, 0)]
     completed_hyp = []
     for length in range(self.max_len):
       if len(completed_hyp) >= self.beam_size:
@@ -161,7 +161,7 @@ def generate_output(self, translator, initial_state, src_length=None, forced_trg
         # Queue next states
         for cur_word in top_words:
          new_score = self.len_norm.normalize_partial_topk(hyp.score, score[cur_word], length + 1)
-          new_set.append(self.Hypothesis(new_score, current_output, hyp, cur_word))
+          new_set.append(self.Hypothesis(new_score, current_output, hyp, cur_word, hyp.length + 1))
 
       # Next top hypothesis
       active_hyp = sorted(new_set, key=lambda x: x.score, reverse=True)[:self.beam_size]
 
     # There is no hyp reached
diff --git a/xnmt/sentence_stats.py b/xnmt/sentence_stats.py
index 4ef37eb69..d064bec15 100644
--- a/xnmt/sentence_stats.py
+++ b/xnmt/sentence_stats.py
@@ -1,13 +1,22 @@
-class SentenceStats(object):
+from xnmt.persistence import serializable_init, Serializable
+
+class SentenceStats(Serializable):
   """
   to Populate the src and trg sents statistics.
   """
+  yaml_tag = '!SentenceStats'
 
-  def __init__(self):
+  @serializable_init
+  def __init__(self, src_file, trg_file):
     self.src_stat = {}
     self.trg_stat = {}
     self.max_pairs = 1000000
     self.num_pair = 0
+    training_corpus_src = open(src_file, 'r').readlines()
+    training_corpus_trg = open(trg_file, 'r').readlines()
+    training_corpus_src = [i.split() for i in training_corpus_src]
+    training_corpus_trg = [i.split() for i in training_corpus_trg]
+    self.populate_statistics(training_corpus_src, training_corpus_trg)
 
 class SourceLengthStat:
   def __init__(self):
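
Usage note (not part of the patch): a minimal sketch, assuming the components can be constructed directly in Python rather than through the experiment YAML, of what the new GaussianNormalization options do. The corpus paths "train.src"/"train.trg" and the concrete length values are placeholders.

from xnmt.sentence_stats import SentenceStats
from xnmt.length_normalization import GaussianNormalization

# SentenceStats now reads and whitespace-tokenizes the parallel training corpus
# itself and populates the src/trg length statistics (placeholder paths below).
stats = SentenceStats(src_file="train.src", trg_file="train.trg")

# Default behavior, unchanged: fit a single Gaussian over target lengths.
norm_default = GaussianNormalization(sent_stats=stats)

# length_ratio=True: fit a Gaussian over the ratio len(trg) / len(src) instead.
norm_ratio = GaussianNormalization(sent_stats=stats, length_ratio=True)

# src_cond=True: fit one Gaussian over target lengths per observed source length,
# with unseen source lengths falling back to a nearby or the longest observed one.
norm_cond = GaussianNormalization(sent_stats=stats, src_cond=True)

# trg_length_prob now takes (src_length, trg_length) in every mode.
p = norm_cond.trg_length_prob(20, 18)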
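
A second sketch, equally illustrative (the word id, score, and source length below are made up), of how the extended Hypothesis tuple and the revised normalize_completed signature fit together; in a real run these hypotheses are built inside BeamSearch.generate_output.

from xnmt.search_strategy import BeamSearch

# Hypothesis now carries an explicit length: 0 for the root hypothesis,
# and parent.length + 1 each time a hypothesis is extended by one word.
root = BeamSearch.Hypothesis(score=0, output=None, parent=None, word=None, length=0)
hyp = BeamSearch.Hypothesis(score=-4.2, output=None, parent=root, word=17, length=root.length + 1)

# normalize_completed now expects one source length per completed hypothesis,
# because it zips completed_hyps with src_length.
# norm_cond is the src_cond GaussianNormalization from the sketch above.
scores = norm_cond.normalize_completed([hyp], src_length=[20])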