Merge pull request #82 from neulab/train-test-behavior
Train test behavior
msperber committed Jun 3, 2017
2 parents 61fa023 + 857cbab commit d4c6661
Showing 5 changed files with 47 additions and 5 deletions.
3 changes: 2 additions & 1 deletion xnmt/decoder.py
@@ -2,8 +2,9 @@
 from mlp import MLP
 import inspect
 from batcher import *
+from translator import TrainTestInterface
 
-class Decoder:
+class Decoder(TrainTestInterface):
   '''
   A template class to convert a prefix of previously generated words and
   a context vector into a probability distribution over possible next words.
8 changes: 6 additions & 2 deletions xnmt/encoder.py
@@ -4,8 +4,9 @@
 import pyramidal
 import conv_encoder
 from embedder import ExpressionSequence
+from translator import TrainTestInterface
 
-class Encoder:
+class Encoder(TrainTestInterface):
   """
   A parent class representing all classes that encode inputs.
   """
@@ -70,9 +71,12 @@ def use_params(self, encoder_spec, params, map_to_default_layer_dim):
 
 class BiLSTMEncoder(BuilderEncoder):
   def init_builder(self, encoder_spec, model):
-    params = self.use_params(encoder_spec, ["layers", "input_dim", "hidden_dim", model, dy.VanillaLSTMBuilder],
+    params = self.use_params(encoder_spec, ["layers", "input_dim", "hidden_dim", model, dy.VanillaLSTMBuilder, "dropout"],
                              map_to_default_layer_dim=["hidden_dim"])
+    self.dropout = params.pop()
     self.builder = dy.BiRNNBuilder(*params)
+  def set_train(self, val):
+    self.builder.set_dropout(self.dropout if val else 0.0)
 
 class ResidualLSTMEncoder(BuilderEncoder):
   def init_builder(self, encoder_spec, model):
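The `BiLSTMEncoder` change follows a simple pattern: read the dropout rate from the encoder spec, keep it on the object, and have `set_train` toggle it on the underlying DyNet builder. A minimal standalone sketch of that toggle, using the same DyNet calls the commit uses (the dimensions and the 0.5 rate are illustrative, not from this commit):

```python
import dynet as dy

model = dy.ParameterCollection()  # called dy.Model() in DyNet releases of this era
builder = dy.BiRNNBuilder(1, 32, 64, model, dy.VanillaLSTMBuilder)

dropout = 0.5
builder.set_dropout(dropout)  # training: dropout active in the LSTM layers
builder.set_dropout(0.0)      # evaluation: dropout disabled for deterministic output
```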
36 changes: 35 additions & 1 deletion xnmt/translator.py
@@ -5,7 +5,26 @@
 from search_strategy import *
 from vocab import Vocab
 
-class Translator:
+class TrainTestInterface:
+  """
+  All subcomponents of the translator that behave differently at train and test time
+  should subclass this class.
+  """
+  def set_train(self, val):
+    """
+    Will be called with val=True when starting to train, and with val=False when starting
+    to evaluate.
+    :param val: bool that indicates whether we're in training mode
+    """
+    pass
+  def get_train_test_components(self):
+    """
+    :returns: list of subcomponents that implement TrainTestInterface and will be called recursively.
+    """
+    return []
+
+
+class Translator(TrainTestInterface):
   '''
   A template class implementing an end-to-end translator that can calculate a
   loss and generate translations.
@@ -28,6 +47,16 @@ def translate(self, src):
     '''
     raise NotImplementedError('translate must be implemented for Translator subclasses')
 
+  def set_train(self, val):
+    for component in self.get_train_test_components():
+      Translator.set_train_recursive(component, val)
+  @staticmethod
+  def set_train_recursive(component, val):
+    component.set_train(val)
+    for sub_component in component.get_train_test_components():
+      Translator.set_train_recursive(sub_component, val)
+
+
 class DefaultTranslator(Translator):
   '''
   A default translator based on attentional sequence-to-sequence models.
@@ -48,6 +77,9 @@ def __init__(self, input_embedder, encoder, attender, output_embedder, decoder):
     self.output_embedder = output_embedder
     self.decoder = decoder
 
+  def get_train_test_components(self):
+    return [self.encoder, self.decoder]
+
   def calc_loss(self, src, trg):
     embeddings = self.input_embedder.embed_sent(src)
     encodings = self.encoder.transduce(embeddings)
@@ -96,3 +128,5 @@ def translate(self, src, search_strategy=None):
     self.decoder.initialize()
     output.append(search_strategy.generate_output(self.decoder, self.attender, self.output_embedder, src_length=len(sents)))
     return output
+
+
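Because `set_train_recursive` descends through `get_train_test_components`, a single `set_train` call on the translator reaches arbitrarily nested components. A self-contained sketch with toy classes (hypothetical, not from the repository; it assumes `xnmt/` is on the import path):

```python
from translator import TrainTestInterface, Translator

class ToyLayer(TrainTestInterface):
  """Leaf component that records the current mode."""
  def __init__(self):
    self.training = False
  def set_train(self, val):
    self.training = val

class ToyEncoder(TrainTestInterface):
  """Mid-level component owning a nested sub-component."""
  def __init__(self):
    self.training = False
    self.sub = ToyLayer()
  def set_train(self, val):
    self.training = val
  def get_train_test_components(self):
    return [self.sub]

class ToyTranslator(Translator):
  def __init__(self):
    self.encoder = ToyEncoder()
  def get_train_test_components(self):
    return [self.encoder]

t = ToyTranslator()
t.set_train(True)   # cascades: translator -> encoder -> sub-layer
assert t.encoder.training and t.encoder.sub.training
t.set_train(False)
assert not t.encoder.training and not t.encoder.sub.training
```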
1 change: 1 addition & 0 deletions xnmt/xnmt_decode.py
@@ -61,6 +61,7 @@ def xnmt_decode(args, model_elements=None):
 
   # Perform decoding
 
+  translator.set_train(False)
   with open(args.trg_file, 'wb') as fp: # Saving the translated output to a trg file
     for src in src_corpus:
       dy.renew_cg()
4 changes: 3 additions & 1 deletion xnmt/xnmt_train.py
@@ -188,6 +188,7 @@ def read_data(self):
   def run_epoch(self):
     self.logger.new_epoch()
 
+    self.translator.set_train(True)
     for batch_num, (src, trg) in enumerate(zip(self.train_src, self.train_trg)):
 
       # Loss calculation
@@ -201,7 +202,7 @@
 
       # Devel reporting
       if self.logger.report_train_process():
-
+        self.translator.set_train(False)
         self.logger.new_dev()
         for src, trg in zip(self.dev_src, self.dev_trg):
           dy.renew_cg()
@@ -221,6 +222,7 @@
           self.early_stopping_reached = True
 
     self.trainer.update_epoch()
+    self.translator.set_train(True)
 
     return math.exp(self.logger.epoch_loss / self.logger.epoch_words), \
            math.exp(self.logger.dev_loss / self.logger.dev_words)
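Taken together, the driver scripts now bracket each phase with an explicit mode switch: `xnmt_train.py` enables train behavior for the epoch, disables it around dev-set scoring, and restores it afterwards, while `xnmt_decode.py` decodes with it disabled. A condensed sketch of the control flow this diff gives `run_epoch` (loop bodies elided with `...`):

```python
def run_epoch(self):
  self.translator.set_train(True)           # dropout active for training batches
  for src, trg in zip(self.train_src, self.train_trg):
    ...                                     # loss calculation and parameter update
    if self.logger.report_train_process():
      self.translator.set_train(False)      # deterministic dev evaluation
      ...                                   # dev loss loop and early-stopping check
  self.trainer.update_epoch()
  self.translator.set_train(True)           # back to train mode for the next epoch
```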
