Merge pull request #512 from neulab/fix_output
Fix output
philip30 committed Aug 10, 2018
2 parents 128c647 + fecb551 commit e1cb477
Showing 5 changed files with 46 additions and 45 deletions.
26 changes: 20 additions & 6 deletions xnmt/input_readers.py
@@ -225,13 +225,23 @@ def vocab_size(self):
     return len(self.vocab)

 class CharFromWordTextReader(PlainTextReader, Serializable):
-  # TODO @philip30
-  # - add documentation
-  # - possibly represent as list of list
+  """
+  Reads in a word-based corpus and turns each line into a SegmentedSentence.
+  A SegmentedSentence's words are characters, but it keeps the word
+  segmentation information:
+    x = SegmentedSentence("i code today")
+    (TRUE) x.words == ["i", "c", "o", "d", "e", "t", "o", "d", "a", "y"]
+    (TRUE) x.segment == [0, 4, 9]
+  This means that the segmentation points (ends of words) fall at positions
+  0, 4, and 9 of the character sequence.
+  """
   yaml_tag = "!CharFromWordTextReader"
   @serializable_init
-  def __init__(self, vocab:Vocab=None, read_sent_len:bool=False):
-    super().__init__(vocab, read_sent_len)
+  def __init__(self, vocab:Vocab=None, read_sent_len:bool=False, output_proc=[]):
+    self.vocab = vocab
+    self.read_sent_len = read_sent_len
+    self.output_procs = output.OutputProcessor.get_output_processor(output_proc)

   def read_sent(self, line, idx):
     chars = []
     segs = []
@@ -242,7 +252,11 @@ def read_sent(self, line, idx):
       chars.extend([c for c in word])
     segs.append(len(chars))
     chars.append(Vocab.ES_STR)
-    sent_input = SegmentedSentence(segment=segs, words=[self.vocab.convert(c) for c in chars], idx=idx)
+    sent_input = SegmentedSentence(segment=segs,
+                                   words=[self.vocab.convert(c) for c in chars],
+                                   idx=idx,
+                                   vocab=self.vocab,
+                                   output_procs=self.output_procs)
     return sent_input

 class H5Reader(InputReader, Serializable):
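The hunk above shows only the tail of read_sent, so the loop bookkeeping that fills segs is partly hidden. A standalone sketch of the behavior the new docstring describes, under the assumption that segs records the index of each word's final character ("</s>" stands in for Vocab.ES_STR, and chars_and_segments is an illustrative name, not xnmt API):

# Reconstructs the char/segment layout from the CharFromWordTextReader
# docstring: one list entry per character, plus the index of each
# word's last character.
def chars_and_segments(line):
    chars, segs = [], []
    for word in line.strip().split():
        chars.extend(word)           # one entry per character
        segs.append(len(chars) - 1)  # position of the word's last character
    chars.append("</s>")             # sentence-end token (Vocab.ES_STR in xnmt)
    return chars, segs

chars, segs = chars_and_segments("i code today")
assert chars == ["i", "c", "o", "d", "e", "t", "o", "d", "a", "y", "</s>"]
assert segs == [0, 4, 9]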
17 changes: 11 additions & 6 deletions xnmt/loss_calculators.py
@@ -56,8 +56,8 @@ class GlobalFertilityLoss(Serializable, LossCalculator):
"""
yaml_tag = '!GlobalFertilityLoss'
@serializable_init
def __init__(self, weight:float = 1) -> None:
self.weight = weight
def __init__(self) -> None:
pass

def calc_loss(self,
model: 'model_base.ConditionedModel',
@@ -71,7 +71,7 @@ def calc_loss(self,
     masked_attn = [dy.cmult(attn, dy.inputTensor(mask, batched=True)) for attn, mask in zip(masked_attn, trg_mask)]

     loss = self.global_fertility(masked_attn)
-    return FactoredLossExpr({"global_fertility": self.weight * loss})
+    return FactoredLossExpr({"global_fertility": loss})

   def global_fertility(self, a):
     return dy.sum_elems(dy.square(1 - dy.esum(a)))
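In plain terms, global_fertility sums each source position's attention over all target steps and penalizes the squared deviation from 1.0. A numeric sketch with Python lists standing in for the DyNet expressions (this global_fertility is a float re-implementation for illustration, not the xnmt method):

# Each inner list is one target step's attention over three source positions.
def global_fertility(attn_per_step):
    totals = [sum(col) for col in zip(*attn_per_step)]  # attention mass per source position
    return sum((1.0 - t) ** 2 for t in totals)

a = [[0.7, 0.2, 0.1],
     [0.5, 0.4, 0.1]]
# totals = [1.2, 0.6, 0.2] -> penalty = 0.04 + 0.16 + 0.64 = 0.84
assert abs(global_fertility(a) - 0.84) < 1e-9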
@@ -82,16 +82,21 @@ class CompositeLoss(Serializable, LossCalculator):
"""
yaml_tag = "!CompositeLoss"
@serializable_init
def __init__(self, losses:List[LossCalculator]):
def __init__(self, losses:List[LossCalculator], loss_weight=None):
self.losses = losses
if loss_weight is None:
self.loss_weight = [1.0 for _ in range(len(losses))]
else:
self.loss_weight = loss_weight
assert len(self.loss_weight) == len(losses)

def calc_loss(self,
model: 'model_base.ConditionedModel',
src: Union[sent.Sentence, 'batchers.Batch'],
trg: Union[sent.Sentence, 'batchers.Batch']):
total_loss = FactoredLossExpr()
for loss in self.losses:
total_loss.add_factored_loss_expr(loss.calc_loss(model, src, trg))
for loss, weight in zip(self.losses, self.loss_weight):
total_loss.add_factored_loss_expr(loss.calc_loss(model, src, trg) * weight)
return total_loss

class ReinforceLoss(Serializable, LossCalculator):
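Together with GlobalFertilityLoss losing its private weight, this moves all weighting into CompositeLoss: each component's FactoredLossExpr is scaled by its entry in loss_weight, defaulting to 1.0 per component. A self-contained sketch of the same weighting logic, using floats in place of FactoredLossExpr (WeightedSum and the lambdas are illustrative stand-ins, not xnmt API):

class WeightedSum:
    def __init__(self, losses, loss_weight=None):
        # Same defaulting as the patch: weight 1.0 for every component.
        self.loss_weight = loss_weight if loss_weight is not None else [1.0] * len(losses)
        assert len(self.loss_weight) == len(losses)
        self.losses = losses

    def calc_loss(self):
        total = 0.0
        for loss_fn, weight in zip(self.losses, self.loss_weight):
            total += weight * loss_fn()  # scale each component before summing
        return total

# e.g. an MLE-style loss at full weight plus a fertility penalty at 0.1:
combined = WeightedSum([lambda: 2.5, lambda: 0.8], loss_weight=[1.0, 0.1])
assert abs(combined.calc_loss() - 2.58) < 1e-9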
2 changes: 2 additions & 0 deletions xnmt/losses.py
@@ -79,6 +79,8 @@ def get_nobackprop_loss(self):
   def __len__(self):
     return len(self.expr_factors)

+  def __mul__(self, scalar):
+    return FactoredLossExpr({key: scalar*value for key, value in self.expr_factors.items()})

 class FactoredLossVal(object):

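This new __mul__ is what lets CompositeLoss write loss.calc_loss(model, src, trg) * weight: it returns a fresh FactoredLossExpr with every named factor scaled, leaving the original untouched. A sketch with a plain dict standing in for expr_factors (Factored is an illustrative stand-in; in xnmt the values are DyNet expressions, not floats):

class Factored:
    def __init__(self, expr_factors):
        self.expr_factors = expr_factors

    def __mul__(self, scalar):
        # Scale every named factor into a new object; self is unchanged.
        return Factored({key: scalar * value for key, value in self.expr_factors.items()})

scaled = Factored({"mle": 2.0, "global_fertility": 0.5}) * 0.5
assert scaled.expr_factors == {"mle": 1.0, "global_fertility": 0.25}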
3 changes: 2 additions & 1 deletion xnmt/output.py
@@ -86,4 +86,5 @@ def __init__(self, space_token="\u2581"):
     self.space_token = space_token

   def process(self, s: str) -> str:
-    return s.replace(self.space_token, " ").strip()
+    return s.replace(" ", "").replace(self.space_token, " ").strip()
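The ordering matters for character-level output, where tokens are single characters joined by spaces and "▁" (U+2581) marks word boundaries: first strip the token-separating spaces, then turn each marker into a real space. A quick check of the new behavior (standalone function, not the xnmt class; the sample string assumes one token per character):

space_token = "\u2581"  # the ▁ word-boundary marker

def process(s):
    return s.replace(" ", "").replace(space_token, " ").strip()

assert process("i \u2581 c o d e \u2581 t o d a y") == "i code today"
# The old single-replace version left the per-character spaces in place,
# yielding "i   c o d e   t o d a y" instead.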

43 changes: 11 additions & 32 deletions xnmt/sent.py
@@ -201,25 +201,12 @@ def len_unpadded(self):
   def create_padded_sent(self, pad_len: int) -> 'SimpleSentence':
     if pad_len == 0:
       return self
-    # Copy is used to copy all possible annotations
-    new_words = self.words + [self.pad_token] * pad_len
-    return SimpleSentence(words=new_words,
-                          idx=self.idx,
-                          vocab=self.vocab,
-                          score=self.score,
-                          output_procs=self.output_procs,
-                          pad_token=self.pad_token)
+    return self.sent_with_new_words(self.words + [self.pad_token] * pad_len)

   def create_truncated_sent(self, trunc_len: int) -> 'SimpleSentence':
     if trunc_len == 0:
       return self
-    new_words = self.words[:-trunc_len]
-    return SimpleSentence(words=new_words,
-                          idx=self.idx,
-                          vocab=self.vocab,
-                          score=self.score,
-                          output_procs=self.output_procs,
-                          pad_token=self.pad_token)
+    return self.sent_with_new_words(self.words[:-trunc_len])

   def str_tokens(self, exclude_ss_es=True, exclude_unk=False, exclude_padded=True, **kwargs) -> List[str]:
     exclude_set = set()
@@ -232,28 +219,20 @@ def str_tokens(self, exclude_ss_es=True, exclude_unk=False, exclude_padded=True,
     if self.vocab: return [self.vocab[w] for w in ret_toks]
     else: return [str(w) for w in ret_toks]

+  def sent_with_new_words(self, new_words):
+    return SimpleSentence(words=new_words,
+                          idx=self.idx,
+                          vocab=self.vocab,
+                          score=self.score,
+                          output_procs=self.output_procs,
+                          pad_token=self.pad_token)

 class SegmentedSentence(SimpleSentence):
   def __init__(self, segment=[], **kwargs) -> None:
     super().__init__(**kwargs)
     self.segment = segment

-  def create_padded_sent(self, pad_len: int) -> 'SimpleSentence':
-    if pad_len == 0:
-      return self
-    # Copy is used to copy all possible annotations
-    new_words = self.words + [self.pad_token] * pad_len
-    return SegmentedSentence(words=new_words,
-                             idx=self.idx,
-                             vocab=self.vocab,
-                             score=self.score,
-                             output_procs=self.output_procs,
-                             pad_token=self.pad_token,
-                             segment=self.segment)
-
-  def create_truncated_sent(self, trunc_len: int) -> 'SimpleSentence':
-    if trunc_len == 0:
-      return self
-    new_words = self.words[:-trunc_len]
+  def sent_with_new_words(self, new_words):
+    return SegmentedSentence(words=new_words,
+                             idx=self.idx,
+                             vocab=self.vocab,
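The net effect in sent.py is a template-method refactor: padding and truncation now live once on SimpleSentence, and each subclass overrides only sent_with_new_words to rebuild itself with its extra annotations (the remainder of the SegmentedSentence hunk, carrying score, output_procs, pad_token, and segment, is truncated above). A minimal sketch of that shape with illustrative stand-in classes, not the xnmt ones:

class Base:
    def __init__(self, words, pad_token="<pad>"):
        self.words, self.pad_token = words, pad_token

    def create_padded(self, pad_len):
        # Shared logic: delegate reconstruction to the subclass hook.
        return self if pad_len == 0 else self.with_new_words(self.words + [self.pad_token] * pad_len)

    def create_truncated(self, trunc_len):
        return self if trunc_len == 0 else self.with_new_words(self.words[:-trunc_len])

    def with_new_words(self, new_words):
        return Base(new_words, self.pad_token)

class Segmented(Base):
    def __init__(self, words, segment=(), **kwargs):
        super().__init__(words, **kwargs)
        self.segment = segment

    # The only override needed: carry the extra annotation along.
    def with_new_words(self, new_words):
        return Segmented(new_words, segment=self.segment, pad_token=self.pad_token)

s = Segmented(["i", "c"], segment=(0,)).create_padded(2)
assert isinstance(s, Segmented) and s.words == ["i", "c", "<pad>", "<pad>"]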
