
Commit

fix models base interface (#509)
* fixed models base interface

* update translator

* remove get_primary_loss
msperber authored and neubig committed Aug 16, 2018
1 parent 38f5089 commit 64b37a1
Showing 6 changed files with 12 additions and 32 deletions.
xnmt/eval/tasks.py (11 changes: 4 additions & 7 deletions)
@@ -90,13 +90,10 @@ def eval(self) -> 'EvalScore':

loss_stats = {k: v/ref_words_cnt for k, v in loss_val.items()}

-    try:
-      return LossScore(loss_stats[self.model.get_primary_loss()],
-                       loss_stats=loss_stats,
-                       num_ref_words = ref_words_cnt,
-                       desc=self.desc)
-    except KeyError:
-      raise RuntimeError("Did you wrap your loss calculation with FactoredLossExpr({'primary_loss': loss_value}) ?")
+    return LossScore(sum(loss_stats.values()),
+                     loss_stats=loss_stats,
+                     num_ref_words = ref_words_cnt,
+                     desc=self.desc)

class AccuracyEvalTask(EvalTask, reports.Reportable, Serializable):
"""
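In effect, LossEvalTask no longer asks the model for a primary-loss key: the reported dev loss is now the sum of all factored loss components. A minimal sketch of the before/after behavior (the dict contents are illustrative, not taken from xnmt):

# per-word loss components, as computed in eval() above; values illustrative
loss_stats = {"mle": 2.31, "policy": 0.42}

# before: look up a single factor under the model-defined primary key,
# raising for models that never registered it
primary_loss = loss_stats["mle"]        # KeyError without an "mle" factor

# after: no key convention needed; report the sum of all factors
dev_loss = sum(loss_stats.values())     # 2.73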
xnmt/models/base.py (17 changes: 6 additions & 11 deletions)
@@ -1,5 +1,7 @@
from typing import Optional, Sequence, Union

+import dynet as dy
+
from xnmt import batchers, input_readers, losses, sent
from xnmt import event_trigger, loss_calculators
from xnmt.persistence import Serializable, serializable_init
@@ -9,7 +11,7 @@ class TrainableModel(object):
A template class for a basic trainable model, implementing a loss function.
"""

-  def calc_nll(self, *args, **kwargs) -> losses.FactoredLossExpr:
+  def calc_nll(self, *args, **kwargs) -> dy.Expression:
"""Calculate loss based on input-output pairs.
Losses are accumulated only across unmasked timesteps in each batch element.
@@ -20,13 +22,6 @@ def calc_nll(self, *args, **kwargs) -> losses.FactoredLossExpr:
A (possibly batched) expression representing the loss.
"""

-  def get_primary_loss(self) -> str:
-    """
-    Returns:
-      Identifier for primary loss.
-    """
-    raise NotImplementedError("Pick a key for primary loss that is used for dev_loss calculation")
-
class UnconditionedModel(TrainableModel):
"""
A template class for trainable model that computes target losses without conditioning on other inputs.
@@ -38,7 +33,7 @@ class UnconditionedModel(TrainableModel):
def __init__(self, trg_reader: input_readers.InputReader):
self.trg_reader = trg_reader

-  def calc_nll(self, trg: Union[batchers.Batch, sent.Sentence]) -> losses.FactoredLossExpr:
+  def calc_nll(self, trg: Union[batchers.Batch, sent.Sentence]) -> dy.Expression:
"""Calculate loss based on target inputs.
Losses are accumulated only across unmasked timesteps in each batch element.
@@ -64,8 +59,8 @@ def __init__(self, src_reader: input_readers.InputReader, trg_reader: input_read
self.src_reader = src_reader
self.trg_reader = trg_reader

-  def calc_nll(self, src: Union[batchers.Batch, sent.Sentence], trg: Union[batchers.Batch, sent.Sentence],
-               loss_calculator: loss_calculators.LossCalculator) -> losses.FactoredLossExpr:
+  def calc_nll(self, src: Union[batchers.Batch, sent.Sentence], trg: Union[batchers.Batch, sent.Sentence]) \
+      -> dy.Expression:
"""Calculate loss based on input-output pairs.
Losses are accumulated only across unmasked timesteps in each batch element.
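Net effect on the base interface: calc_nll now returns a single (possibly batched) dy.Expression instead of a losses.FactoredLossExpr, and the loss calculator is no longer threaded through the model. A hedged sketch of a conforming subclass, assuming DyNet is installed; ToyModel and its parameters are illustrative, not part of xnmt:

import dynet as dy

class ToyModel:
  """Illustrative stand-in for a ConditionedModel subclass."""
  def __init__(self):
    self.pc = dy.ParameterCollection()
    self.W = self.pc.add_parameters((2, 2))

  def calc_nll(self, src, trg) -> dy.Expression:
    # one expression for the whole (possibly batched) input;
    # no FactoredLossExpr wrapper, no loss_calculator argument
    logits = dy.parameter(self.W) * dy.inputVector(src)
    return dy.pickneglogsoftmax(logits, trg)

dy.renew_cg()
print(ToyModel().calc_nll([0.5, -1.0], 1).value())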
xnmt/models/classifiers.py (3 changes: 0 additions & 3 deletions)
@@ -80,9 +80,6 @@ def generate(self, src, forced_trg_ids=None, normalize_scores=False):
score=score))
return outputs

-  def get_primary_loss(self):
-    return "mle"
-
def get_nobp_state(self, state):
output_state = state.rnn_state.output()
return output_state
xnmt/models/language_models.py (3 changes: 0 additions & 3 deletions)
@@ -41,9 +41,6 @@ def __init__(self,
def shared_params(self):
return [{".src_embedder.emb_dim", ".encoder.input_dim"},]

-  def get_primary_loss(self):
-    return "mle"
-
def calc_nll(self, src, trg):
if not batchers.is_batched(src):
src = batchers.ListBatch([src])
xnmt/models/sequence_labelers.py (3 changes: 0 additions & 3 deletions)
@@ -51,9 +51,6 @@ def __init__(self,
def shared_params(self):
return [{".src_embedder.emb_dim", ".encoder.input_dim"},]

-  def get_primary_loss(self):
-    return "mle"
-
def _encode_src(self, src):
event_trigger.start_sent(src)
embeddings = self.src_embedder.embed_sent(src)
xnmt/models/translators.py (7 changes: 2 additions & 5 deletions)
@@ -60,9 +60,6 @@ def set_trg_vocab(self, trg_vocab=None):
"""
self.trg_vocab = trg_vocab

-  def get_primary_loss(self) -> str:
-    return "mle"
-
def get_nobp_state(self, state):
output_state = state.rnn_state.output()
if type(output_state) == EnsembleListDelegate:
@@ -358,7 +355,7 @@ def sentence_block_embed(self, embed, x, mask):
e = dy.reshape(e, (units, length), batch_size=batch)
return e

-  def calc_loss(self, src, trg, loss_cal=None, infer_prediction=False):
+  def calc_loss(self, src, trg, infer_prediction=False):
event_trigger.start_sent(src)
if not batchers.is_batched(src):
src = batchers.mark_as_batch([src])
@@ -491,7 +488,7 @@ def set_trg_vocab(self, trg_vocab=None):
def calc_nll(self, src: Union[batchers.Batch, sent.Sentence], trg: Union[batchers.Batch, sent.Sentence]) -> dy.Expression:
sub_losses = collections.defaultdict(list)
for model in self.models:
-      for loss_name, loss in model.calc_loss(src, trg).expr_factors.items():
+      for loss_name, loss in model.calc_nll(src, trg).expr_factors.items():
sub_losses[loss_name].append(loss)
model_loss = FactoredLossExpr()
for loss_name, losslist in sub_losses.items():
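The ensemble hunk is cut off above; a hedged sketch of the aggregation pattern it implements, with `models`, `src`, and `trg` assumed to be in scope and the per-factor combination across sub-models assumed to be an average:

import collections
import dynet as dy

sub_losses = collections.defaultdict(list)
for model in models:
  # each sub-model still returns factored losses, keyed by name
  for loss_name, loss in model.calc_nll(src, trg).expr_factors.items():
    sub_losses[loss_name].append(loss)

# combine each factor across sub-models; the real code wraps the result
# in a FactoredLossExpr, and averaging here is an assumption
combined = {name: dy.average(exprs) for name, exprs in sub_losses.items()}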
