per_task_backward option for SameBatchMultiTaskTrainingRegimen (#410)
* split update_weights() into backward() and update()

* implement per_task_backward for SameBatchMultiTaskTrainingRegimen

* renew CG
msperber authored and neubig committed May 30, 2018
1 parent 8988979 commit 08ea387
Showing 1 changed file with 49 additions and 25 deletions.
74 changes: 49 additions & 25 deletions xnmt/training_regimen.py
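In effect, the commit replaces the single update_weights() helper with a backward()/update() pair. A minimal sketch of the new call pattern, restated from the diff below (illustrative only, not the actual TrainingRegimen class):

import dynet as dy

def backward(loss: dy.Expression, dynet_profiling: int) -> None:
    # Optionally dump the computation graph, then accumulate gradients.
    if dynet_profiling and dynet_profiling > 0:
        dy.print_text_graphviz()
    loss.backward()

def update(trainer) -> None:
    # Apply the accumulated gradients via the shared DyNet trainer.
    trainer.update()

# before: self.update_weights(loss, self.trainer, self.dynet_profiling)
# after:  self.backward(loss, self.dynet_profiling)
#         self.update(self.trainer)

Splitting the two steps is what lets SameBatchMultiTaskTrainingRegimen accumulate gradients task by task before a single joint update.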
@@ -1,3 +1,5 @@
import argparse
from typing import Sequence
from collections import OrderedDict

from xnmt.settings import settings
@@ -9,37 +11,44 @@
from xnmt.loss_calculator import MLELoss
from xnmt.param_collection import ParamManager
from xnmt.persistence import serializable_init, Serializable, bare, Ref
import xnmt.optimizer
from xnmt.training_task import SimpleTrainingTask
from xnmt import training_task, optimizer, batcher

class TrainingRegimen(object):
"""
A training regimen is a class that implements a training loop.
"""
def run_training(self, save_fct, update_weights=True):
"""
Runs training steps in a loop until stopping criterion is reached.
Run training steps in a loop until stopping criterion is reached.
Args:
save_fct: function to be invoked to save a model at dev checkpoints
update_weights (bool): Whether parameters should be updated
"""
raise NotImplementedError("")
def update_weights(self, loss, trainer, dynet_profiling):

def backward(self, loss: dy.Expression, dynet_profiling: int) -> None:
"""
Standardized way to perform backward pass and parameter updates.
Perform backward pass to accumulate gradients.
Args:
loss: Result of self.training_step(...)
trainer (XnmtOptimizer): DyNet trainer
dynet_profiling (int): if > 0, print the computation graph
dynet_profiling: if > 0, print the computation graph
"""
if dynet_profiling and dynet_profiling > 0:
dy.print_text_graphviz()
loss.backward()

def update(self, trainer: optimizer.XnmtOptimizer) -> None:
"""
Update DyNet weights using the given optimizer.
Args:
trainer: DyNet trainer
"""
trainer.update()

class SimpleTrainingRegimen(SimpleTrainingTask, TrainingRegimen, Serializable):
class SimpleTrainingRegimen(training_task.SimpleTrainingTask, TrainingRegimen, Serializable):
"""
Args:
model (TrainableModel): the model
@@ -76,7 +85,7 @@ class SimpleTrainingRegimen(SimpleTrainingTask, TrainingRegimen, Serializable):

@serializable_init
def __init__(self, model=Ref("model"), src_file=None, trg_file=None, dev_every=0, dev_zero=False,
batcher=bare(xnmt.batcher.SrcBatcher, batch_size=32), loss_calculator=bare(MLELoss), trainer=None,
batcher=bare(batcher.SrcBatcher, batch_size=32), loss_calculator=bare(MLELoss), trainer=None,
run_for_epochs=None, lr_decay=1.0, lr_decay_times=3, patience=1, initial_patience=None, dev_tasks=None,
dev_combinator=None, restart_trainer: bool = False,
reload_command=None, name="{EXP}", sample_train_sents=None,
@@ -104,7 +113,7 @@ def __init__(self, model=Ref("model"), src_file=None, trg_file=None, dev_every=0
max_src_len=max_src_len,
max_trg_len=max_trg_len)
self.dev_zero = dev_zero
self.trainer = trainer or xnmt.optimizer.SimpleSGDTrainer(e0=0.1)
self.trainer = trainer or optimizer.SimpleSGDTrainer(e0=0.1)
self.dynet_profiling = getattr(commandline_args, "dynet_profiling", 0) if commandline_args else 0
self.train_loss_tracker = TrainLossTracker(self)

@@ -122,7 +131,9 @@ def run_training(self, save_fct, update_weights=True):
self.model.set_train(True)
loss_builder = self.training_step(src, trg)
loss = loss_builder.compute()
if update_weights: self.update_weights(loss, self.trainer, self.dynet_profiling)
if update_weights:
self.backward(loss, self.dynet_profiling)
self.update(self.trainer)
self.train_loss_tracker.report(trg, loss_builder.get_loss_stats())
if self.checkpoint_needed():
self.checkpoint_and_save(save_fct)
@@ -156,7 +167,7 @@ def __init__(self,
self.dynet_profiling = getattr(commandline_args, "dynet_profiling", 0) if commandline_args else 0
if len(tasks)==0: raise ValueError("Task list must be non-empty.")
self.tasks = tasks
self.trainer = trainer or xnmt.optimizer.SimpleSGDTrainer(e0=0.1)
self.trainer = trainer or optimizer.SimpleSGDTrainer(e0=0.1)
for task in tasks[1:]:
if hasattr(task, "trainer") and task.trainer is not None:
raise ValueError("Can instantiate only one trainer object. Possibly, multiple training regimens were created when training tasks should have been used.")
@@ -183,24 +194,29 @@ def trigger_train_event(self, value):

class SameBatchMultiTaskTrainingRegimen(MultiTaskTrainingRegimen, Serializable):
"""
Multi-task training where gradients are accumulated and weight updates
are thus performed jointly for each task. The relative weight between
tasks can be configured by setting each tasks batch size accordingly.
Multi-task training where gradients are accumulated and weight updates are thus performed jointly across tasks.
The relative weight between tasks can be configured by setting each task's batch size accordingly.
The stopping criterion of the first task is used (other tasks' stopping criteria are ignored).
Args:
tasks (List[TrainingTask]): training tasks
trainer (XnmtOptimizer): the trainer is shared across tasks
dev_zero (bool): if True, add a checkpoint before training loop is entered (useful with pretrained models).
commandline_args (Namespace):
tasks: training tasks
trainer: the trainer is shared across tasks
dev_zero: if True, add a checkpoint before the training loop is entered (useful with pretrained models).
per_task_backward: if ``True``, call backward() for each task separately and renew the computation graph
                   between tasks. Both settings yield the same results, but ``True`` uses less memory, while
                   ``False`` may be faster when using autobatching.
commandline_args:
"""
yaml_tag = "!SameBatchMultiTaskTrainingRegimen"

@serializable_init
def __init__(self, tasks, trainer=None, dev_zero=False,
commandline_args=Ref("exp_global.commandline_args", default=None)):
def __init__(self, tasks: Sequence[training_task.TrainingTask], trainer: optimizer.XnmtOptimizer = None,
dev_zero: bool = False, per_task_backward: bool = True,
commandline_args: argparse.Namespace = Ref("exp_global.commandline_args", default=None)):
super().__init__(tasks=tasks, trainer=trainer, dev_zero=dev_zero, commandline_args=commandline_args)
self.train_loss_trackers = {task : TrainLossTracker(task) for task in tasks}
self.per_task_backward = per_task_backward

def run_training(self, save_fct, update_weights=True):
task_generators = OrderedDict()
for task in self.tasks:
@@ -221,9 +237,15 @@ def run_training(self, save_fct, update_weights=True):
for task, src, trg in task_src_trg:
loss_builder = task.training_step(src, trg)
task_trg_loss_stats[task] = (trg, loss_builder.get_loss_stats())
task_losses.append(loss_builder.compute())
if self.per_task_backward:
self.backward(loss_builder.compute(), self.dynet_profiling)
dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE, check_validity=settings.CHECK_VALIDITY)
else:
task_losses.append(loss_builder.compute())
if update_weights:
self.update_weights(sum(task_losses), self.trainer, self.dynet_profiling)
if not self.per_task_backward:
self.backward(sum(task_losses), self.dynet_profiling)
self.update(self.trainer)
for task, (trg, stats) in task_trg_loss_stats.items():
self.train_loss_trackers[task].report(trg, stats)
self.checkpoint_and_save(save_fct)
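The branching above can be condensed into the following illustrative sketch; loss tracking, checkpointing, the update_weights flag, and the settings-dependent renew_cg arguments are omitted, and the helper name joint_training_step is made up for this example:

import dynet as dy

def joint_training_step(regimen, task_src_trg, per_task_backward: bool):
    # task_src_trg yields (task, src, trg) triples, one per task, for the current round.
    task_losses = []
    for task, src, trg in task_src_trg:
        loss_builder = task.training_step(src, trg)
        if per_task_backward:
            # Backprop each task's loss immediately, then discard its graph
            # so only one task's graph is held in memory at a time.
            regimen.backward(loss_builder.compute(), regimen.dynet_profiling)
            dy.renew_cg()
        else:
            # Keep all losses in one graph and backprop their sum once,
            # which may be faster when autobatching is enabled.
            task_losses.append(loss_builder.compute())
    if not per_task_backward:
        regimen.backward(sum(task_losses), regimen.dynet_profiling)
    # Either way, a single joint parameter update follows.
    regimen.update(regimen.trainer)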
@@ -284,7 +306,8 @@ def run_training(self, save_fct, update_weights=True):
self.trigger_train_event(True)
loss_builder = cur_task.training_step(src, trg)
if update_weights:
self.update_weights(loss=loss_builder.compute(), trainer=self.trainer, dynet_profiling=self.dynet_profiling)
self.backward(loss=loss_builder.compute(), dynet_profiling=self.dynet_profiling)
self.update(trainer=self.trainer)
cur_train_loss_tracker.report(trg, loss_builder.get_loss_stats())
self.checkpoint_and_save(cur_task, cur_task_i, save_fct, dev_zero)
if self.tasks[0].should_stop_training(): break
@@ -336,7 +359,8 @@ def run_training(self, save_fct, update_weights=True):
loss_builder = cur_task.training_step(src, trg)
task_loss = loss_builder.compute()
if update_weights:
self.update_weights(task_loss, self.trainer, self.dynet_profiling)
self.backward(task_loss, self.dynet_profiling)
self.update(self.trainer)
cur_train_loss_tracker.report(trg, loss_builder.get_loss_stats())
self.checkpoint_and_save(cur_task, cur_task_id, save_fct, dev_zero)
if cur_task.should_stop_training(): break
