Skip to content

Commit

Permalink
Cleanup events and training task (#500)
Browse files: browse the repository at this point in the history
* interface clean up

* clean up events

* revert rename

* fix unit tests
  • Loading branch information
msperber committed Aug 3, 2018
1 parent c0e30a4 commit e06711b
Show file tree
Hide file tree
Showing 19 changed files with 232 additions and 238 deletions.
10 changes: 5 additions & 5 deletions test/test_beam_search.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -3,7 +3,7 @@
import dynet as dy

from xnmt.modelparts.attenders import MlpAttender
from xnmt import batchers, events
from xnmt import batchers, event_trigger, events
from xnmt.modelparts.bridges import CopyBridge
from xnmt.modelparts.decoders import AutoRegressiveDecoder
from xnmt.modelparts.embedders import SimpleWordEmbedder
Expand Down Expand Up @@ -44,7 +44,7 @@ def setUp(self):
scorer=Softmax(input_dim=layer_dim, vocab_size=100),
bridge=CopyBridge(dec_dim=layer_dim, dec_layers=1)),
)
self.model.set_train(False)
event_trigger.set_train(False)

self.src_data = list(self.model.src_reader.read_sents("examples/data/head.ja"))
self.trg_data = list(self.model.trg_reader.read_sents("examples/data/head.en"))
Expand Down Expand Up @@ -81,7 +81,7 @@ def setUp(self):
scorer=Softmax(input_dim=layer_dim, vocab_size=100),
bridge=CopyBridge(dec_dim=layer_dim, dec_layers=1)),
)
self.model.set_train(False)
event_trigger.set_train(False)

self.src_data = list(self.model.src_reader.read_sents("examples/data/head.ja"))
self.trg_data = list(self.model.trg_reader.read_sents("examples/data/head.en"))
Expand Down Expand Up @@ -117,7 +117,7 @@ def setUp(self):
scorer=Softmax(input_dim=layer_dim, vocab_size=100),
bridge=CopyBridge(dec_dim=layer_dim, dec_layers=1)),
)
self.model.set_train(False)
event_trigger.set_train(False)

self.src_data = list(self.model.src_reader.read_sents("examples/data/head.ja"))
self.trg_data = list(self.model.trg_reader.read_sents("examples/data/head.en"))
Expand Down Expand Up @@ -156,7 +156,7 @@ def setUp(self):
scorer=Softmax(input_dim=layer_dim, vocab_size=100),
bridge=CopyBridge(dec_dim=layer_dim, dec_layers=1)),
)
self.model.set_train(False)
event_trigger.set_train(False)

self.src_data = list(self.model.src_reader.read_sents("examples/data/head.ja"))
self.trg_data = list(self.model.trg_reader.read_sents("examples/data/head.en"))
Expand Down
8 changes: 4 additions & 4 deletions test/test_decoders.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -3,7 +3,7 @@
import dynet as dy

from xnmt.modelparts.attenders import MlpAttender
from xnmt import batchers, events
from xnmt import batchers, event_trigger, events
from xnmt.modelparts.bridges import CopyBridge
from xnmt.modelparts.decoders import AutoRegressiveDecoder
from xnmt.modelparts.embedders import SimpleWordEmbedder
Expand Down Expand Up @@ -44,7 +44,7 @@ def setUp(self):
scorer=Softmax(input_dim=layer_dim, vocab_size=100),
bridge=CopyBridge(dec_dim=layer_dim, dec_layers=1)),
)
self.model.set_train(False)
event_trigger.set_train(False)

self.src_data = list(self.model.src_reader.read_sents("examples/data/head.ja"))
self.trg_data = list(self.model.trg_reader.read_sents("examples/data/head.en"))
Expand Down Expand Up @@ -83,7 +83,7 @@ def setUp(self):
scorer=Softmax(input_dim=layer_dim, vocab_size=100),
bridge=CopyBridge(dec_dim=layer_dim, dec_layers=1)),
)
self.model.set_train(False)
event_trigger.set_train(False)

self.src_data = list(self.model.src_reader.read_sents("examples/data/head.ja"))
self.trg_data = list(self.model.trg_reader.read_sents("examples/data/head.en"))
Expand Down Expand Up @@ -120,7 +120,7 @@ def setUp(self):
scorer=Softmax(input_dim=layer_dim, vocab_size=100),
bridge=CopyBridge(dec_dim=layer_dim, dec_layers=1)),
)
self.model.set_train(False)
event_trigger.set_train(False)

self.src_data = list(self.model.src_reader.read_sents("examples/data/head.ja"))
self.trg_data = list(self.model.trg_reader.read_sents("examples/data/head.en"))
Expand Down
21 changes: 7 additions & 14 deletions test/test_encoders.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -17,7 +17,7 @@
from xnmt.modelparts.transforms import NonLinear
from xnmt.models.translators import DefaultTranslator
from xnmt.vocabs import Vocab
from xnmt import batchers, events
from xnmt import batchers, event_trigger, events

class TestEncoder(unittest.TestCase):

Expand All @@ -32,18 +32,11 @@ def setUp(self):
self.src_data = list(self.src_reader.read_sents("examples/data/head.ja"))
self.trg_data = list(self.trg_reader.read_sents("examples/data/head.en"))

@events.register_xnmt_event
def set_train(self, val):
pass
@events.register_xnmt_event
def start_sent(self, src):
pass

def assert_in_out_len_equal(self, model):
dy.renew_cg()
self.set_train(True)
event_trigger.set_train(True)
src = self.src_data[0]
self.start_sent(src)
event_trigger.start_sent(src)
embeddings = model.src_embedder.embed_sent(src)
encodings = model.encoder.transduce(embeddings)
self.assertEqual(len(embeddings), len(encodings))
Expand Down Expand Up @@ -119,11 +112,11 @@ def test_py_lstm_encoder_len(self):
scorer=Softmax(input_dim=layer_dim, vocab_size=100),
bridge=CopyBridge(dec_dim=layer_dim, dec_layers=1)),
)
self.set_train(True)
event_trigger.set_train(True)
for sent_i in range(10):
dy.renew_cg()
src = self.src_data[sent_i].create_padded_sent(4 - (self.src_data[sent_i].sent_len() % 4))
self.start_sent(src)
event_trigger.start_sent(src)
embeddings = model.src_embedder.embed_sent(src)
encodings = model.encoder.transduce(embeddings)
self.assertEqual(int(math.ceil(len(embeddings) / float(4))), len(encodings))
Expand All @@ -149,11 +142,11 @@ def test_py_lstm_mask(self):
train_src, _ = \
batcher.pack(self.src_data, self.trg_data)

self.set_train(True)
event_trigger.set_train(True)
for sent_i in range(3):
dy.renew_cg()
src = train_src[sent_i]
self.start_sent(src)
event_trigger.start_sent(src)
embeddings = model.src_embedder.embed_sent(src)
encodings = model.encoder.transduce(embeddings)
if train_src[sent_i].mask is None:
Expand Down
28 changes: 14 additions & 14 deletions test/test_segmenting.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -11,7 +11,7 @@
from xnmt.modelparts.decoders import AutoRegressiveDecoder
from xnmt.modelparts.embedders import SimpleWordEmbedder
import xnmt.events
from xnmt import batchers
from xnmt import batchers, event_trigger
from xnmt.input_readers import PlainTextReader
from xnmt.input_readers import CharFromWordTextReader
from xnmt.transducers.recurrent import UniLSTMSeqTransducer
Expand Down Expand Up @@ -83,7 +83,7 @@ def setUp(self):
trg_embed_dim=layer_dim,
bridge=CopyBridge(dec_dim=layer_dim, dec_layers=1)),
)
self.model.set_train(True)
event_trigger.set_train(True)

self.layer_dim = layer_dim
self.src_data = list(self.model.src_reader.read_sents("examples/data/head.ja"))
Expand All @@ -96,7 +96,7 @@ def test_reinforce_loss(self):
fertility_loss = GlobalFertilityLoss()
mle_loss = MLELoss()
loss = CompositeLoss(losses=[mle_loss, fertility_loss]).calc_loss(self.model, self.src[0], self.trg[0])
reinforce_loss = self.model.calc_additional_loss(self.trg[0], self.model, loss)
reinforce_loss = event_trigger.calc_additional_loss(self.trg[0], self.model, loss)
pl = self.model.encoder.policy_learning
# Ensure correct length
src = self.src[0]
Expand All @@ -117,7 +117,7 @@ def test_reinforce_loss(self):

def calc_loss_single_batch(self):
loss = MLELoss().calc_loss(self.model, self.src[0], self.trg[0])
reinforce_loss = self.model.calc_additional_loss(self.trg[0], self.model, loss)
reinforce_loss = event_trigger.calc_additional_loss(self.trg[0], self.model, loss)
return loss, reinforce_loss

def test_gold_input(self):
Expand All @@ -134,24 +134,24 @@ def test_sample_input(self):
self.assertEqual(self.model.encoder.policy_learning.sampling_action, PolicyGradient.SamplingAction.PREDEFINED)

def test_policy_train_test(self):
self.model.set_train(True)
event_trigger.set_train(True)
self.calc_loss_single_batch()
self.assertEqual(self.model.encoder.policy_learning.sampling_action, PolicyGradient.SamplingAction.POLICY_CLP)
self.model.set_train(False)
event_trigger.set_train(False)
self.calc_loss_single_batch()
self.assertEqual(self.model.encoder.policy_learning.sampling_action, PolicyGradient.SamplingAction.POLICY_AMAX)

def test_no_policy_train_test(self):
self.model.encoder.policy_learning = None
self.model.set_train(True)
event_trigger.set_train(True)
self.calc_loss_single_batch()
self.assertEqual(self.model.encoder.segmenting_action, SegmentingSeqTransducer.SegmentingAction.PURE_SAMPLE)
self.model.set_train(False)
event_trigger.set_train(False)
self.calc_loss_single_batch()
self.assertEqual(self.model.encoder.segmenting_action, SegmentingSeqTransducer.SegmentingAction.PURE_SAMPLE)

def test_sample_during_search(self):
self.model.set_train(False)
event_trigger.set_train(False)
self.model.encoder.sample_during_search = True
self.calc_loss_single_batch()
self.assertEqual(self.model.encoder.segmenting_action, SegmentingSeqTransducer.SegmentingAction.POLICY)
Expand Down Expand Up @@ -199,7 +199,7 @@ def setUp(self):
trg_embed_dim=layer_dim,
bridge=CopyBridge(dec_dim=layer_dim, dec_layers=1)),
)
self.model.set_train(True)
event_trigger.set_train(True)

self.layer_dim = layer_dim
self.src_data = list(self.model.src_reader.read_sents("examples/data/head.ja"))
Expand All @@ -209,7 +209,7 @@ def setUp(self):
dy.renew_cg(immediate_compute=True, check_validity=True)

def inp_emb(self, idx=0):
self.model.start_sent(self.src[idx])
event_trigger.start_sent(self.src[idx])
embed = self.model.src_embedder.embed_sent(self.src[idx])
return embed

Expand Down Expand Up @@ -268,19 +268,19 @@ def test_convolution_composer(self):
enc.segment_composer = ConvolutionComposer(ngram_size=1,
embed_dim=self.layer_dim,
hidden_dim=self.layer_dim)
self.model.set_train(True)
event_trigger.set_train(True)
enc.transduce(self.inp_emb(0))
enc.segment_composer = ConvolutionComposer(ngram_size=3,
embed_dim=self.layer_dim,
hidden_dim=self.layer_dim)
self.model.set_train(True)
event_trigger.set_train(True)
enc.transduce(self.inp_emb(0))

def test_transducer_composer(self):
enc = self.segmenting_encoder
enc.segment_composer = SeqTransducerComposer(seq_transducer=BiLSTMSeqTransducer(input_dim=self.layer_dim,
hidden_dim=self.layer_dim))
self.model.set_train(True)
event_trigger.set_train(True)
enc.transduce(self.inp_emb(0))

if __name__ == "__main__":
Expand Down
16 changes: 8 additions & 8 deletions test/test_training.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -21,7 +21,7 @@
from xnmt.models.translators import DefaultTranslator
from xnmt.modelparts.scorers import Softmax
from xnmt.vocabs import Vocab
from xnmt import sent
from xnmt import event_trigger, sent

class TestTruncatedBatchTraining(unittest.TestCase):

Expand Down Expand Up @@ -91,7 +91,7 @@ def test_loss_model1(self):
scorer=Softmax(input_dim=layer_dim, vocab_size=100),
bridge=CopyBridge(dec_dim=layer_dim, dec_layers=1)),
)
model.set_train(False)
event_trigger.set_train(False)
self.assert_single_loss_equals_batch_loss(model)

def test_loss_model2(self):
Expand All @@ -113,7 +113,7 @@ def test_loss_model2(self):
scorer=Softmax(input_dim=layer_dim, vocab_size=100),
bridge=CopyBridge(dec_dim=layer_dim, dec_layers=1)),
)
model.set_train(False)
event_trigger.set_train(False)
self.assert_single_loss_equals_batch_loss(model, pad_src_to_multiple=4)

def test_loss_model3(self):
Expand All @@ -135,7 +135,7 @@ def test_loss_model3(self):
scorer=Softmax(input_dim=layer_dim, vocab_size=100),
bridge=CopyBridge(dec_dim=layer_dim, dec_layers=1)),
)
model.set_train(False)
event_trigger.set_train(False)
self.assert_single_loss_equals_batch_loss(model)

def test_loss_model4(self):
Expand All @@ -157,7 +157,7 @@ def test_loss_model4(self):
scorer=Softmax(input_dim=layer_dim, vocab_size=100),
bridge=CopyBridge(dec_dim=layer_dim, dec_layers=1)),
)
model.set_train(False)
event_trigger.set_train(False)
self.assert_single_loss_equals_batch_loss(model)

class TestBatchTraining(unittest.TestCase):
Expand Down Expand Up @@ -232,7 +232,7 @@ def test_loss_model1(self):
scorer=Softmax(input_dim=layer_dim, vocab_size=100),
bridge=CopyBridge(dec_dim=layer_dim, dec_layers=1)),
)
model.set_train(False)
event_trigger.set_train(False)
self.assert_single_loss_equals_batch_loss(model)

def test_loss_model2(self):
Expand All @@ -254,7 +254,7 @@ def test_loss_model2(self):
scorer=Softmax(input_dim=layer_dim, vocab_size=100),
bridge=CopyBridge(dec_dim=layer_dim, dec_layers=1)),
)
model.set_train(False)
event_trigger.set_train(False)
self.assert_single_loss_equals_batch_loss(model, pad_src_to_multiple=4)

def test_loss_model3(self):
Expand All @@ -276,7 +276,7 @@ def test_loss_model3(self):
scorer=Softmax(input_dim=layer_dim, vocab_size=100),
bridge=CopyBridge(dec_dim=layer_dim, dec_layers=1)),
)
model.set_train(False)
event_trigger.set_train(False)
self.assert_single_loss_equals_batch_loss(model)


Expand Down
8 changes: 4 additions & 4 deletions xnmt/eval/tasks.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -12,7 +12,7 @@
from xnmt.eval.metrics import LossScore
from xnmt.losses import FactoredLossExpr, FactoredLossVal
import xnmt.xnmt_evaluate
from xnmt import events, reports, utils
from xnmt import event_trigger, events, reports, utils

class EvalTask(object):
"""
Expand Down Expand Up @@ -62,7 +62,7 @@ def eval(self) -> 'EvalScore':
Returns:
Evaluated score
"""
self.model.set_train(False)
event_trigger.set_train(False)
if self.src_data is None:
self.src_data, self.ref_data, self.src_batches, self.ref_batches = \
input_readers.read_parallel_corpus(src_reader=self.model.src_reader,
Expand Down Expand Up @@ -126,7 +126,7 @@ def __init__(self, src_file: Union[str,Sequence[str]], ref_file: Union[str,Seque
self.desc=desc

def eval(self):
self.model.set_train(False)
event_trigger.set_train(False)
self.report_corpus_info({"ref_file":self.ref_file})
self.inference.perform_inference(generator=self.model,
src_file=self.src_file,
Expand Down Expand Up @@ -160,7 +160,7 @@ def __init__(self, src_file: Union[str,Sequence[str]], hyp_file: str, model: 'mo
self.inference = inference or self.model.inference

def eval(self):
self.model.set_train(False)
event_trigger.set_train(False)
self.inference.perform_inference(generator=self.model,
src_file=self.src_file,
trg_file=self.hyp_file)
Expand Down

0 comments on commit e06711b

Please sign in to comment.