Resetting of high scores, scheduler and optimizer for fine-tuning/domain adaptation #75

Merged
merged 2 commits on Nov 25, 2019
5 changes: 4 additions & 1 deletion configs/small.yaml
@@ -23,7 +23,10 @@ testing: # specify which inference algorithm to use f
    alpha: 1.0 # length penalty for beam search

training: # specify training details here
    #load_model: "my_model/50.ckpt" # if given, load a pre-trained model from this checkpoint
    #load_model: "models/small_model/60.ckpt" # if given, load a pre-trained model from this checkpoint
    reset_best_ckpt: False # if True, reset the tracking of the best checkpoint and scores. Use for domain adaptation or fine-tuning with new metrics or dev data.
    reset_scheduler: False # if True, overwrite scheduler in loaded checkpoint with parameters specified in this config. Use for domain adaptation or fine-tuning.
    reset_optimizer: False # if True, overwrite optimizer in loaded checkpoint with parameters specified in this config. Use for domain adaptation or fine-tuning.
    random_seed: 42 # set this seed to make training deterministic
    optimizer: "adam" # choices: "sgd", "adam", "adadelta", "adagrad", "rmsprop", default is SGD
    adam_betas: [0.9, 0.999] # beta parameters for Adam. These are the defaults. Typically these are different for Transformer models.
5 changes: 4 additions & 1 deletion configs/transformer_small.yaml
@@ -21,7 +21,10 @@ testing: # specify which inference algorithm to use f
    alpha: 1.0 # length penalty for beam search

training: # specify training details here
    #load_model: "my_model/50.ckpt" # if given, load a pre-trained model from this checkpoint
    #load_model: "models/transformer/60.ckpt" # if given, load a pre-trained model from this checkpoint
    reset_best_ckpt: False # if True, reset the tracking of the best checkpoint and scores. Use for domain adaptation or fine-tuning with new metrics or dev data.
    reset_scheduler: False # if True, overwrite scheduler in loaded checkpoint with parameters specified in this config. Use for domain adaptation or fine-tuning.
    reset_optimizer: False # if True, overwrite optimizer in loaded checkpoint with parameters specified in this config. Use for domain adaptation or fine-tuning.
    random_seed: 42 # set this seed to make training deterministic
    optimizer: "adam" # choices: "sgd", "adam", "adadelta", "adagrad", "rmsprop", default is SGD
    adam_betas: [0.9, 0.98] # beta parameters for Adam. These are the defaults. Typically these are different for Transformer models.
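Both configuration files gain the same three options with identical defaults. As a rough sketch of how they might be combined for fine-tuning (the checkpoint path below is a placeholder, and which states you reset depends on your setup):

```yaml
training:
    load_model: "models/transformer/best.ckpt"  # placeholder: out-of-domain checkpoint to start from
    reset_best_ckpt: True      # forget the best score tracked on the old dev data
    reset_scheduler: True      # rebuild the scheduler from this config instead of the checkpoint
    reset_optimizer: True      # rebuild the optimizer from this config instead of the checkpoint
```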
5 changes: 4 additions & 1 deletion docs/source/faq.rst
@@ -38,12 +38,15 @@ Training
Depends on the size of your data. For most use cases you want to validate at least once per epoch.
Say you have 100k training examples and train with mini-batches of size 20; then set ``validation_freq`` to 5000 (100k/20) to validate once per epoch, as sketched below.
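As a minimal sketch, with the numbers from the example above (``batch_size`` here counts sentences per mini-batch):

.. code-block:: yaml

    training:
        batch_size: 20          # mini-batches of 20 examples
        validation_freq: 5000   # 100k examples / 20 per batch = validate once per epoch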

- **How can I perform domain adaptation?**
- **How can I perform domain adaptation or fine-tuning?**
Both approaches are similar, so we call the fine-tuning data *in-domain* data in the following.
1. First train your model on one dataset (the *out-of-domain* data).
2. Modify the original configuration file (or better a copy of it) in the data section to point to the new *in-domain* data.
Specify which vocabularies to use: ``src_vocab: out-of-domain-model/src_vocab.txt`` and likewise for ``trg_vocab``.
You have to specify this; otherwise JoeyNMT will try to build a new vocabulary from the new in-domain data, which would not match the vocabulary the out-of-domain model was built with.
In the training section, specify which checkpoint of the out-of-domain model you want to start adapting: ``load_model: out-of-domain-model/best.ckpt``.
If you set ``reset_best_ckpt: True``, previously stored high scores under your metric will be ignored, and if you set ``reset_scheduler`` and ``reset_optimizer`` you can also overwrite the stored scheduler and optimizer with the new ones in your configuration.
Use this if the scores on your new dev set are lower than on the old dev set, or if you use a different metric or schedule for fine-tuning. A sketch of such a configuration is given after this list.
3. Train the in-domain model.
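A minimal sketch of such a fine-tuning configuration, assuming the out-of-domain model lives in ``out-of-domain-model/`` as above (the in-domain data paths are placeholders and only the keys relevant to adaptation are shown):

.. code-block:: yaml

    data:
        train: "in-domain-data/train"   # placeholder: new in-domain training data
        dev: "in-domain-data/dev"       # placeholder: new in-domain dev data
        src_vocab: "out-of-domain-model/src_vocab.txt"   # reuse the out-of-domain vocabulary
        trg_vocab: "out-of-domain-model/trg_vocab.txt"

    training:
        load_model: "out-of-domain-model/best.ckpt"   # checkpoint to adapt
        reset_best_ckpt: True       # old high scores do not apply to the new dev set
        reset_scheduler: False      # set True to use the schedule from this config instead
        reset_optimizer: False      # set True to use the optimizer from this config instead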

- **What if training is interrupted and I need to resume it?**
43 changes: 35 additions & 8 deletions joeynmt/training.py
@@ -153,7 +153,13 @@ def __init__(self, model: Model, config: dict) -> None:
if "load_model" in train_config.keys():
model_load_path = train_config["load_model"]
self.logger.info("Loading model from %s", model_load_path)
self.init_from_checkpoint(model_load_path)
reset_best_ckpt = train_config.get("reset_best_ckpt", False)
reset_scheduler = train_config.get("reset_scheduler", False)
reset_optimizer = train_config.get("reset_optimizer", False)
self.init_from_checkpoint(model_load_path,
reset_best_ckpt=reset_best_ckpt,
reset_scheduler=reset_scheduler,
reset_optimizer=reset_optimizer)

    def _save_checkpoint(self) -> None:
        """
@@ -196,31 +202,52 @@ def _save_checkpoint(self) -> None:
        # overwrite best.ckpt
        torch.save(state, best_path)

    def init_from_checkpoint(self, path: str) -> None:
    def init_from_checkpoint(self, path: str,
                             reset_best_ckpt: bool = False,
                             reset_scheduler: bool = False,
                             reset_optimizer: bool = False) -> None:
        """
        Initialize the trainer from a given checkpoint file.

        This checkpoint file contains not only model parameters, but also
        scheduler and optimizer states, see `self._save_checkpoint`.

        :param path: path to checkpoint
        :param reset_best_ckpt: reset tracking of the best checkpoint,
            use for domain adaptation with a new dev
            set or when using a new metric for fine-tuning.
        :param reset_scheduler: reset the learning rate scheduler, and do not
            use the one stored in the checkpoint.
        :param reset_optimizer: reset the optimizer, and do not use the one
            stored in the checkpoint.
        """
        model_checkpoint = load_checkpoint(path=path, use_cuda=self.use_cuda)

        # restore model and optimizer parameters
        self.model.load_state_dict(model_checkpoint["model_state"])

        self.optimizer.load_state_dict(model_checkpoint["optimizer_state"])
        if not reset_optimizer:
            self.optimizer.load_state_dict(model_checkpoint["optimizer_state"])
        else:
            self.logger.info("Reset optimizer.")

        if model_checkpoint["scheduler_state"] is not None and \
                self.scheduler is not None:
            self.scheduler.load_state_dict(model_checkpoint["scheduler_state"])
        if not reset_scheduler:
            if model_checkpoint["scheduler_state"] is not None and \
                    self.scheduler is not None:
                self.scheduler.load_state_dict(
                    model_checkpoint["scheduler_state"])
        else:
            self.logger.info("Reset scheduler.")

        # restore counts
        self.steps = model_checkpoint["steps"]
        self.total_tokens = model_checkpoint["total_tokens"]
        self.best_ckpt_score = model_checkpoint["best_ckpt_score"]
        self.best_ckpt_iteration = model_checkpoint["best_ckpt_iteration"]

        if not reset_best_ckpt:
            self.best_ckpt_score = model_checkpoint["best_ckpt_score"]
            self.best_ckpt_iteration = model_checkpoint["best_ckpt_iteration"]
        else:
            self.logger.info("Reset tracking of the best checkpoint.")

        # move parameters to cuda
        if self.use_cuda: