adding features to Gaussian Normalization #479

Open
wants to merge 1 commit into base: master
1 change: 1 addition & 0 deletions xnmt/__init__.py
@@ -58,6 +58,7 @@
 import xnmt.persistence
 import xnmt.rl
 import xnmt.compound_expr
+import xnmt.sentence_stats
 
 resolved_serialize_params = {}
79 changes: 64 additions & 15 deletions xnmt/length_normalization.py
@@ -142,28 +142,77 @@ class GaussianNormalization(LengthNormalization, Serializable):
   yaml_tag = '!GaussianNormalization'
 
   @serializable_init
-  def __init__(self, sent_stats):
-    self.stats = sent_stats.trg_stat
+  def __init__(self, sent_stats, length_ratio=False, src_cond=False):
+    """
+    Args:
+      sent_stats: A SentenceStats object
+      length_ratio: Instead of fitting len(trg) in the training set, fit the length ratio len(trg) / len(src)
+      src_cond: Instead of fitting len(trg) in the training set, fit the length of trg with the given len(src)
+    """
+    self.sent_stats = sent_stats
     self.num_sent = sent_stats.num_pair
+    self.length_ratio = length_ratio
+    self.src_cond = src_cond
     self.fit_distribution()
 
   def fit_distribution(self):
-    y = np.zeros(self.num_sent)
-    curr_iter = 0
-    for key in self.stats:
-      iter_end = self.stats[key].num_sents + curr_iter
-      y[curr_iter:iter_end] = key
-      curr_iter = iter_end
-    mu, std = norm.fit(y)
-    self.distr = norm(mu, std)
-
-  def trg_length_prob(self, trg_length):
-    return self.distr.pdf(trg_length)
+    if self.length_ratio:
+      stats = self.sent_stats.src_stat
+      num_sent = self.sent_stats.num_pair
+      y = np.zeros(num_sent)
+      iter = 0
+      for key in stats:
+        for t_len, count in stats[key].trg_len_distribution.items():
+          iter_end = count + iter
+          y[iter:iter_end] = t_len / float(key)
+          iter = iter_end
+      mu, std = norm.fit(y)
+      self.distr = norm(mu, std)
+    elif self.src_cond:
+      stats = self.sent_stats.src_stat
+      self.distr = {}
+      self.max_key = -1
+      for key in stats:
+        if key > self.max_key: self.max_key = key
+        num_trg = stats[key].num_sents
+        y = np.zeros(num_trg)
+        iter = 0
+        for t_len, count in stats[key].trg_len_distribution.items():
+          iter_end = count + iter
+          y[iter:iter_end] = t_len
+          iter = iter_end
+        mu, std = norm.fit(y)
+        if std == 0: std = np.sqrt(key)
+        self.distr[key] = norm(mu, std)
+      for i in range(self.max_key-1, -1, -1):
+        if i not in self.distr:
+          self.distr[i] = self.distr[i+1]
+    else:
+      stats = self.sent_stats.trg_stat
+      y = np.zeros(self.num_sent)
+      curr_iter = 0
+      for key in stats:
+        iter_end = stats[key].num_sents + curr_iter
+        y[curr_iter:iter_end] = key
+        curr_iter = iter_end
+      mu, std = norm.fit(y)
+      self.distr = norm(mu, std)
+
+  def trg_length_prob(self, src_length, trg_length):
+    if self.length_ratio:
+      assert (src_length is not None), "Length of Source Sentence is required in GaussianNormalization when length_ratio=True"
+      return self.distr.pdf(trg_length/src_length)
+    elif self.src_cond:
+      if src_length in self.distr:
+        return self.distr[src_length].pdf(trg_length)
+      else:
+        return self.distr[self.max_key].pdf(trg_length)
+    else:
+      return self.distr.pdf(trg_length)
 
   def normalize_completed(self, completed_hyps: Sequence['search_strategy.BeamSearch.Hypothesis'],
                           src_length: Optional[int] = None) -> Sequence[float]:
-    return [hyp.score / self.trg_length_prob(len(hyp.id_list)) for hyp in completed_hyps]
-
+    return [hyp.score / self.trg_length_prob(src_len, hyp.length) for hyp, src_len in zip(completed_hyps, src_length)]
+
 class EosBooster(Serializable):
   """
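For readers skimming the diff: the patch gives GaussianNormalization three modes. By default it fits a Gaussian to target lengths seen in training; with length_ratio=True it fits the ratio len(trg) / len(src); with src_cond=True it fits one Gaussian per source length, falling back to the distribution for the largest seen source length otherwise. The following standalone sketch illustrates only the length_ratio idea; the toy data and the normalize helper are made up for illustration and are not part of the PR.

import numpy as np
from scipy.stats import norm

# Hypothetical training-set lengths; in the PR these come from SentenceStats.
src_lengths = np.array([10, 12, 20, 7, 15])
trg_lengths = np.array([11, 13, 18, 8, 17])

# Fit a Gaussian to the target/source length ratios (the length_ratio mode).
mu, std = norm.fit(trg_lengths / src_lengths)
ratio_distr = norm(mu, std)

# Mirror normalize_completed: divide the (typically negative, log-domain)
# model score by the density of the hypothesis's length ratio, so
# hypotheses with unlikely lengths are penalized more strongly.
def normalize(score, src_len, trg_len):
    return score / ratio_distr.pdf(trg_len / src_len)

print(normalize(-4.2, src_len=12, trg_len=13))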
6 changes: 3 additions & 3 deletions xnmt/search_strategy.py
@@ -114,7 +114,7 @@ class BeamSearch(Serializable, SearchStrategy):
   """
 
   yaml_tag = '!BeamSearch'
-  Hypothesis = namedtuple('Hypothesis', ['score', 'output', 'parent', 'word'])
+  Hypothesis = namedtuple('Hypothesis', ['score', 'output', 'parent', 'word', 'length'])
 
   @serializable_init
   def __init__(self, beam_size: int = 1, max_len: int = 100, len_norm: LengthNormalization = bare(NoNormalization),
@@ -132,7 +132,7 @@ def generate_output(self, translator, initial_state, src_length=None, forced_trg
       logger.warning("Forced decoding with a target longer than max_len. "
                      "Increase max_len to avoid unexpected behavior.")
 
-    active_hyp = [self.Hypothesis(0, None, None, None)]
+    active_hyp = [self.Hypothesis(0, None, None, None, 0)]
     completed_hyp = []
     for length in range(self.max_len):
      if len(completed_hyp) >= self.beam_size:
@@ -161,7 +161,7 @@ def generate_output(self, translator, initial_state, src_length=None, forced_trg
         # Queue next states
         for cur_word in top_words:
           new_score = self.len_norm.normalize_partial_topk(hyp.score, score[cur_word], length + 1)
-          new_set.append(self.Hypothesis(new_score, current_output, hyp, cur_word))
+          new_set.append(self.Hypothesis(new_score, current_output, hyp, cur_word, hyp.length + 1))
         # Next top hypothesis
         active_hyp = sorted(new_set, key=lambda x: x.score, reverse=True)[:self.beam_size]
         # There is no hyp reached </s>
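The search_strategy.py change threads a running length through beam search: Hypothesis gains a length field that starts at 0 for the root hypothesis and is incremented by one on every expansion, so normalize_completed can read hyp.length directly instead of reconstructing the length from the hypothesis. A small self-contained sketch of that bookkeeping (illustrative only, not code from the PR):

from collections import namedtuple

# Field names taken from the extended namedtuple in the diff.
Hypothesis = namedtuple('Hypothesis', ['score', 'output', 'parent', 'word', 'length'])

# Each expansion copies the parent and increments length by one, so a
# completed hypothesis carries its own token count.
root = Hypothesis(0.0, None, None, None, 0)
child = Hypothesis(-1.3, None, root, 42, root.length + 1)
assert child.length == 1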
13 changes: 11 additions & 2 deletions xnmt/sentence_stats.py
@@ -1,13 +1,22 @@
-class SentenceStats(object):
+from xnmt.persistence import serializable_init, Serializable
+
+class SentenceStats(Serializable):
   """
   to Populate the src and trg sents statistics.
   """
+  yaml_tag = '!SentenceStats'
 
-  def __init__(self):
+  @serializable_init
+  def __init__(self, src_file, trg_file):
     self.src_stat = {}
     self.trg_stat = {}
     self.max_pairs = 1000000
     self.num_pair = 0
+    training_corpus_src = open(src_file, 'r').readlines()
+    training_corpus_trg = open(trg_file, 'r').readlines()
+    training_corpus_src = [i.split() for i in training_corpus_src]
+    training_corpus_trg = [i.split() for i in training_corpus_trg]
+    self.populate_statistics(training_corpus_src, training_corpus_trg)
 
   class SourceLengthStat:
     def __init__(self):

Review comment (Contributor), on the new __init__ signature:
Could you document the arguments? Another thing to consider is that this design will cause the statistics to be recomputed even when loading a saved model. If this takes some amount of time or is otherwise inconvenient (e.g. because it requires keeping the files around at the same location), it would be possible to do something similar to the vocabs that only load the vocab from file if the i2w argument is not set, and then use saved_processed_arg() to store the result so that the vocab file will not need to be opened when loading the model: https://github.com/neulab/xnmt/blob/master/xnmt/vocab.py#L25

Review comment (Contributor):
I agree with @msperber here. @cindyxinyiwang, could you take a look at it?
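For reference, the pattern the reviewer points to (xnmt/vocab.py) computes a value once and stores the processed result so the source file does not need to be reopened when a saved model is loaded. A rough sketch of how SentenceStats could adopt it; the stats argument, the _compute_stats helper, and the exact save_processed_arg usage are assumptions modeled on Vocab, not code from this PR:

from xnmt.persistence import serializable_init, Serializable

class SentenceStats(Serializable):
  """
  Sketch only: compute statistics once, then persist the processed result
  so reloading a saved model does not require the original corpus files.
  """
  yaml_tag = '!SentenceStats'

  @serializable_init
  def __init__(self, src_file=None, trg_file=None, stats=None):
    if stats is None:
      # First run: read the corpora and compute the statistics.
      stats = self._compute_stats(src_file, trg_file)
    self.stats = stats
    # Mirror Vocab's handling of i2w: store the processed value and clear
    # the file arguments so they are not needed when the model is reloaded
    # (save_processed_arg signature assumed from xnmt/vocab.py).
    self.save_processed_arg("stats", stats)
    self.save_processed_arg("src_file", None)
    self.save_processed_arg("trg_file", None)

  @staticmethod
  def _compute_stats(src_file, trg_file):
    # Placeholder computation: token counts per sentence pair.
    with open(src_file) as fs, open(trg_file) as ft:
      return [(len(s.split()), len(t.split())) for s, t in zip(fs, ft)]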