-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
/
checkpoints.py
38 lines (32 loc) · 1.29 KB
/
checkpoints.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import logging
from typing import Any, Dict
from flair.trainers.plugins.base import TrainerPlugin
# Shared project-wide logger; checkpoint-save notices below are emitted through it.
log = logging.getLogger("flair")
class CheckpointPlugin(TrainerPlugin):
    """Trainer plugin that saves a model checkpoint every k training epochs.

    After each epoch, if ``save_model_each_k_epochs`` is positive and the
    epoch number is a multiple of it, the current model is written to
    ``base_path / "model_epoch_<epoch>.pt"``.
    """

    def __init__(
        self,
        save_model_each_k_epochs: int,
        save_optimizer_state: bool,
        base_path,
    ) -> None:
        """Initialize the checkpoint plugin.

        Args:
            save_model_each_k_epochs: Save a checkpoint every k epochs;
                a value of 0 (or negative) disables periodic saving.
            save_optimizer_state: Passed through as ``checkpoint=`` to
                ``model.save`` so the optimizer state is stored alongside
                the model weights.
            base_path: Directory the checkpoint files are written into.
                NOTE(review): used with the ``/`` operator below, so this is
                presumably a ``pathlib.Path`` — confirm with callers.
        """
        super().__init__()
        self.save_optimizer_state = save_optimizer_state
        self.save_model_each_k_epochs = save_model_each_k_epochs
        self.base_path = base_path

    @TrainerPlugin.hook
    def after_training_epoch(self, epoch, **kw):
        """Saves the model each k epochs."""
        if self.save_model_each_k_epochs > 0 and epoch % self.save_model_each_k_epochs == 0:
            log.info(
                f"Saving model at current epoch since 'save_model_each_k_epochs={self.save_model_each_k_epochs}' "
                f"was set"
            )
            # Checkpoint files are named by epoch so successive saves never overwrite.
            model_name = f"model_epoch_{epoch}.pt"
            self.model.save(self.base_path / model_name, checkpoint=self.save_optimizer_state)

    def get_state(self) -> Dict[str, Any]:
        """Return this plugin's configuration merged into the base plugin state."""
        return {
            **super().get_state(),
            "base_path": str(self.base_path),
            "save_model_each_k_epochs": self.save_model_each_k_epochs,
            "save_optimizer_state": self.save_optimizer_state,
        }