Split off function to find latest checkpoint in Trainer (allenai#1414)
Quick refactor to expose the functionality to find the latest checkpoint in the serialization directory in `Trainer`.
matt-peters authored and DeNeutoy committed Jun 23, 2018
1 parent 11041f4 commit b75786c
Showing 1 changed file with 32 additions and 18 deletions.
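A minimal usage sketch (assumptions only, not code from this commit): `trainer` below stands for an already-constructed `Trainer` with a `serialization_dir`, and shows how the newly exposed `find_latest_checkpoint` could be called to decide whether training can resume.

```python
# Illustrative only: `trainer` is a hypothetical, already-constructed Trainer
# whose serialization_dir may or may not contain saved checkpoints.
latest = trainer.find_latest_checkpoint()
if latest is None:
    # No valid checkpoint on disk; training would start from epoch 0.
    print("no checkpoint found")
else:
    model_path, training_state_path = latest
    print("resuming from {} and {}".format(model_path, training_state_path))
```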
50 changes: 32 additions & 18 deletions allennlp/training/trainer.py
@@ -812,30 +812,16 @@ def _save_checkpoint(self,
for fname in paths_to_remove[1:]:
os.remove(fname)

def _restore_checkpoint(self) -> Tuple[int, List[float]]:
def find_latest_checkpoint(self) -> Tuple[str, str]:
"""
Restores a model from a serialization_dir to the last saved checkpoint.
This includes an epoch count and optimizer state, which is serialized separately
from model parameters. This function should only be used to continue training -
if you wish to load a model for inference/load parts of a model into a new
computation graph, you should use the native Pytorch functions:
`` model.load_state_dict(torch.load("/path/to/model/weights.th"))``
If ``self._serialization_dir`` does not exist or does not contain any checkpointed weights,
this function will do nothing and return 0.
Returns
-------
epoch: int
The epoch at which to resume training, which should be one after the epoch
in the saved training state.
Return the location of the latest model and training state files.
If there isn't a valid checkpoint then return None.
"""
have_checkpoint = (self._serialization_dir is not None and
any("model_state_epoch_" in x for x in os.listdir(self._serialization_dir)))

if not have_checkpoint:
# No checkpoint to restore, start at 0
return 0, []
return None

serialization_files = os.listdir(self._serialization_dir)
model_checkpoints = [x for x in serialization_files if "model_state_epoch" in x]
@@ -867,6 +853,34 @@ def _restore_checkpoint(self) -> Tuple[int, List[float]]:
training_state_path = os.path.join(self._serialization_dir,
"training_state_epoch_{}.th".format(epoch_to_load))

return (model_path, training_state_path)

def _restore_checkpoint(self) -> Tuple[int, List[float]]:
"""
Restores a model from a serialization_dir to the last saved checkpoint.
This includes an epoch count and optimizer state, which is serialized separately
from model parameters. This function should only be used to continue training -
if you wish to load a model for inference/load parts of a model into a new
computation graph, you should use the native Pytorch functions:
`` model.load_state_dict(torch.load("/path/to/model/weights.th"))``
If ``self._serialization_dir`` does not exist or does not contain any checkpointed weights,
this function will do nothing and return 0.
Returns
-------
epoch: int
The epoch at which to resume training, which should be one after the epoch
in the saved training state.
"""
latest_checkpoint = self.find_latest_checkpoint()

if latest_checkpoint is None:
# No checkpoint to restore, start at 0
return 0, []

model_path, training_state_path = latest_checkpoint

# Load the parameters onto CPU, then transfer to GPU.
# This avoids potential OOM on GPU for large models that
# load parameters onto GPU then make a new GPU copy into the parameter
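The truncated comment above refers to the CPU-first loading pattern. A self-contained PyTorch sketch of that pattern follows; the stand-in model and checkpoint path are assumptions for illustration, not values from this commit.

```python
import torch

# Illustrative stand-in model and hypothetical checkpoint path.
model = torch.nn.Linear(4, 2)
model_path = "/path/to/serialization_dir/model_state_epoch_3.th"

# Load the parameters onto CPU, then transfer to GPU.  map_location="cpu"
# keeps torch.load from restoring tensors directly onto the GPU, which for a
# large model could briefly hold two GPU copies of the parameters and OOM.
model_state = torch.load(model_path, map_location=torch.device("cpu"))
model.load_state_dict(model_state)
if torch.cuda.is_available():
    model = model.cuda()
```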
