Skip to content

Commit

Permalink
Fix type annotations and minor issues (#515)
Browse files Browse the repository at this point in the history
* fix some type annotations

* add missing range

* translator: check for compound inputs

* trigger event first

* remove extra event trigger

* another try to fix event trigger
  • Loading branch information
msperber committed Aug 13, 2018
1 parent 74ea9eb commit 38f5089
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 8 deletions.
2 changes: 1 addition & 1 deletion xnmt/batchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def sent_len(self):
return sum(b.sent_len() for b in self.batches)

def __iter__(self):
for i in self.batch_size():
for i in range(self.batch_size()):
yield sent.CompoundSentence(sents=[b[i] for b in self.batches])

def __getitem__(self, key):
Expand Down
11 changes: 8 additions & 3 deletions xnmt/eval/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,16 @@ class LossEvalTask(EvalTask, Serializable):
yaml_tag = '!LossEvalTask'

@serializable_init
def __init__(self, src_file: str, ref_file: Optional[str] = None, model: 'model_base.GeneratorModel' = Ref("model"),
def __init__(self,
src_file: Union[str, Sequence[str]],
ref_file: Optional[str] = None,
model: 'model_base.GeneratorModel' = Ref("model"),
batcher: Batcher = Ref("train.batcher", default=bare(xnmt.batchers.SrcBatcher, batch_size=32)),
loss_calculator: LossCalculator = bare(MLELoss), max_src_len: Optional[int] = None,
loss_calculator: LossCalculator = bare(MLELoss),
max_src_len: Optional[int] = None,
max_trg_len: Optional[int] = None,
loss_comb_method: str = Ref("exp_global.loss_comb_method", default="sum"), desc: Any = None):
loss_comb_method: str = Ref("exp_global.loss_comb_method", default="sum"),
desc: Any = None):
self.model = model
self.loss_calculator = loss_calculator
self.src_file = src_file
Expand Down
4 changes: 2 additions & 2 deletions xnmt/experiments.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional, Dict
from typing import Any, Dict, List, Optional

from xnmt.param_initializers import ParamInitializer, GlorotInitializer, ZeroInitializer
from xnmt.settings import settings
Expand Down Expand Up @@ -48,7 +48,7 @@ def __init__(self,
loss_comb_method: str = "sum",
compute_report: bool = False,
commandline_args: dict = {},
placeholders: Dict[str, str] = {}) -> None:
placeholders: Dict[str, Any] = {}) -> None:
self.model_file = model_file
self.log_file = log_file
self.dropout = dropout
Expand Down
7 changes: 5 additions & 2 deletions xnmt/models/translators.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@ def shared_params(self):


def _encode_src(self, src: Union[batchers.Batch, sent.Sentence]):
event_trigger.start_sent(src)
embeddings = self.src_embedder.embed_sent(src)
encoding = self.encoder.transduce(embeddings)
final_state = self.encoder.get_final_states()
Expand All @@ -129,6 +128,8 @@ def _encode_src(self, src: Union[batchers.Batch, sent.Sentence]):
return initial_state

def calc_nll(self, src: Union[batchers.Batch, sent.Sentence], trg: Union[batchers.Batch, sent.Sentence]) -> dy.Expression:
event_trigger.start_sent(src)
if isinstance(src, batchers.CompoundBatch): src = src.batches[0]
# Encode the sentence
initial_state = self._encode_src(src)

Expand Down Expand Up @@ -201,8 +202,9 @@ def generate_search_output(self,
if src.batch_size()!=1:
raise NotImplementedError("batched decoding not implemented for DefaultTranslator. "
"Specify inference batcher with batch size 1.")
# Generating outputs
event_trigger.start_sent(src)
if isinstance(src, batchers.CompoundBatch): src = src.batches[0]
# Generating outputs
cur_forced_trg = None
src_sent = src[0]
sent_mask = None
Expand Down Expand Up @@ -233,6 +235,7 @@ def generate(self,
"""
assert src.batch_size() == 1
search_outputs = self.generate_search_output(src, search_strategy, forced_trg_ids)
if isinstance(src, batchers.CompoundBatch): src = src.batches[0]
sorted_outputs = sorted(search_outputs, key=lambda x: x.score[0], reverse=True)
assert len(sorted_outputs) >= 1
outputs = []
Expand Down

0 comments on commit 38f5089

Please sign in to comment.