Save train time in checkpoint #169

Merged: 4 commits, Nov 24, 2023
7 changes: 7 additions & 0 deletions test/trainer/test_default_trainer.py
@@ -56,8 +56,11 @@ def _get_kwargs(self, with_roi=False, compile_model=False):

def test_fit(self):
from torch_em.trainer import DefaultTrainer

trainer = DefaultTrainer(**self._get_kwargs())
trainer.fit(10)
train_time = trainer.train_time
self.assertGreater(train_time, 0.0)

save_folder = os.path.join(self.checkpoint_folder, self.name)
self.assertTrue(os.path.exists(save_folder))
@@ -69,13 +72,15 @@ def test_fit(self):

trainer.fit(2)
self.assertEqual(trainer.iteration, 12)
self.assertGreater(trainer.train_time, train_time)

trainer = DefaultTrainer(**self._get_kwargs())
trainer.fit(8, load_from_checkpoint="latest")
self.assertEqual(trainer.iteration, 20)

def test_from_checkpoint(self):
from torch_em.trainer import DefaultTrainer

trainer = DefaultTrainer(**self._get_kwargs(with_roi=True))
trainer.fit(10)
exp_model = trainer.model
@@ -86,6 +91,7 @@ def test_from_checkpoint(self):
name="latest"
)
self.assertEqual(trainer.iteration, trainer2.iteration)
self.assertEqual(trainer.train_time, trainer2.train_time)
self.assertEqual(trainer2.train_loader.dataset.raw.shape, exp_data_shape)
self.assertTrue(torch_em.util.model_is_equal(exp_model, trainer2.model))

@@ -100,6 +106,7 @@ def test_from_checkpoint(self):
@unittest.skipIf(sys.version_info.minor > 10, "Not supported for python > 3.10")
def test_compiled_model(self):
from torch_em.trainer import DefaultTrainer

trainer = DefaultTrainer(**self._get_kwargs(compile_model=True))
trainer.fit(10)
exp_model = trainer.model
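The test changes above check that the trainer now tracks cumulative training time across runs. As a minimal sketch of the expected behavior (the trainer kwargs are placeholders, not the actual test fixture):

    from torch_em.trainer import DefaultTrainer

    trainer = DefaultTrainer(**kwargs)   # kwargs: placeholder for your usual trainer setup
    trainer.fit(10)
    first_run_time = trainer.train_time  # greater than 0 after the first fit

    trainer.fit(2)                       # continue training with the same trainer instance
    assert trainer.train_time > first_run_time  # the total time keeps accumulating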
3 changes: 2 additions & 1 deletion torch_em/self_training/fix_match.py
@@ -136,7 +136,7 @@ def __init__(
# functionality for saving checkpoints and initialization
#

def save_checkpoint(self, name, best_metric):
def save_checkpoint(self, name, best_metric, **extra_save_dict):
train_loader_kwargs = get_constructor_arguments(self.train_loader)
val_loader_kwargs = get_constructor_arguments(self.val_loader)
extra_state = {
@@ -151,6 +151,7 @@ def save_checkpoint(self, name, best_metric):
"metric_kwargs": {},
},
}
extra_state.update(**extra_save_dict)
super().save_checkpoint(name, best_metric, **extra_state)

# distribution alignment - encourages the distribution of the model's generated pseudo labels to match the marginal
3 changes: 2 additions & 1 deletion torch_em/self_training/mean_teacher.py
@@ -171,7 +171,7 @@ def _momentum_update(self):
# functionality for saving checkpoints and initialization
#

def save_checkpoint(self, name, best_metric):
def save_checkpoint(self, name, best_metric, **extra_save_dict):
train_loader_kwargs = get_constructor_arguments(self.train_loader)
val_loader_kwargs = get_constructor_arguments(self.val_loader)
extra_state = {
@@ -187,6 +187,7 @@ def save_checkpoint(self, name, best_metric):
"metric_kwargs": {},
},
}
extra_state.update(**extra_save_dict)
super().save_checkpoint(name, best_metric, **extra_state)

def load_checkpoint(self, checkpoint="best"):
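Both the FixMatch and MeanTeacher trainers now follow the same convention: save_checkpoint accepts **extra_save_dict and forwards it, merged into the trainer-specific extra state, to the base class. This is what lets the train_time keyword from the base fit loop pass through unchanged. A hypothetical subclass following the same pattern could look like this (MyTrainer and my_state are illustrative only):

    from torch_em.trainer import DefaultTrainer

    class MyTrainer(DefaultTrainer):
        def save_checkpoint(self, name, best_metric, **extra_save_dict):
            # trainer-specific state to store alongside the standard checkpoint contents
            extra_state = {"my_state": {"momentum": 0.999}}
            # forward anything the caller passed through (e.g. train_time from fit)
            extra_state.update(**extra_save_dict)
            super().save_checkpoint(name, best_metric, **extra_state)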
18 changes: 14 additions & 4 deletions torch_em/trainer/default_trainer.py
@@ -69,6 +69,7 @@ def __init__(

self.mixed_precision = mixed_precision
self.early_stopping = early_stopping
self.train_time = 0.0

self.scaler = amp.GradScaler() if mixed_precision else None

@@ -457,7 +458,7 @@ def _initialize(self, iterations, load_from_checkpoint, epochs=None):
best_metric = np.inf
return best_metric

def save_checkpoint(self, name, best_metric, **extra_save_dict):
def save_checkpoint(self, name, best_metric, train_time=0.0, **extra_save_dict):
save_path = os.path.join(self.checkpoint_folder, f"{name}.pt")
extra_init_dict = extra_save_dict.pop("init", {})
save_dict = {
@@ -468,6 +469,7 @@ def save_checkpoint(self, name, best_metric, **extra_save_dict):
"model_state": self.model.state_dict(),
"optimizer_state": self.optimizer.state_dict(),
"init": self.init_data | extra_init_dict,
"train_time": train_time,
}
save_dict.update(**extra_save_dict)
if self.scaler is not None:
@@ -492,6 +494,7 @@ def load_checkpoint(self, checkpoint="best"):
self._epoch = save_dict["epoch"]
self._best_epoch = save_dict["best_epoch"]
self.best_metric = save_dict["best_metric"]
self.train_time = save_dict.get("train_time", 0.0)

model_state = save_dict["model_state"]
# to enable loading compiled models
@@ -549,6 +552,7 @@ def fit(self, iterations=None, load_from_checkpoint=None, epochs=None, save_ever
msg = "Epoch %i: average [s/it]: %f, current metric: %f, best metric: %f"

train_epochs = self.max_epoch - self._epoch
t_start = time.time()
for _ in range(train_epochs):

# run training and validation for this epoch
@@ -561,19 +565,22 @@ def fit(self, iterations=None, load_from_checkpoint=None, epochs=None, save_ever
if self.lr_scheduler is not None:
self.lr_scheduler.step(current_metric)

# how long did we train in total?
total_train_time = (time.time() - t_start) + self.train_time

# save this checkpoint as the new best checkpoint if
# it has the best overall validation metric
if current_metric < best_metric:
best_metric = current_metric
self._best_epoch = self._epoch
self.save_checkpoint("best", best_metric)
self.save_checkpoint("best", best_metric, train_time=total_train_time)

# save this checkpoint as the latest checkpoint
self.save_checkpoint("latest", best_metric)
self.save_checkpoint("latest", best_metric, train_time=total_train_time)

# if we save after every k-th epoch then check if we need to save now
if save_every_kth_epoch is not None and (self._epoch + 1) % save_every_kth_epoch == 0:
self.save_checkpoint(f"epoch-{self._epoch + 1}", best_metric)
self.save_checkpoint(f"epoch-{self._epoch + 1}", best_metric, train_time=total_train_time)

# if early stopping has been specified then check if the stopping condition is met
if self.early_stopping is not None:
@@ -591,6 +598,9 @@ def fit(self, iterations=None, load_from_checkpoint=None, epochs=None, save_ever
if self._generate_name:
self.name = None

# Update the train time
self.train_time = total_train_time

# TODO save the model to wandb if we have the wandb logger
if isinstance(self.logger, WandbLogger):
self.logger.get_wandb().finish()
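Since checkpoints are written to <checkpoint_folder>/<name>.pt and now carry a "train_time" entry, the stored training time can also be inspected directly from the file. A small sketch, where the path is an assumption about your local checkpoint layout:

    import torch

    # path is hypothetical; adjust to your checkpoint folder and checkpoint name
    save_dict = torch.load(
        "checkpoints/my-model/latest.pt",
        map_location="cpu",
        weights_only=False,  # the checkpoint also stores plain Python objects (e.g. the init data)
    )
    print("total training time [s]:", save_dict.get("train_time", 0.0))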
5 changes: 2 additions & 3 deletions torch_em/trainer/spoco_trainer.py
@@ -32,9 +32,8 @@ def _momentum_update(self):
for param_model, param_teacher in zip(self.model.parameters(), self.model2.parameters()):
param_teacher.data = param_teacher.data * self.momentum + param_model.data * (1. - self.momentum)

def save_checkpoint(self, name, best_metric):
model2_state = {"model2_state": self.model2.state_dict()}
super().save_checkpoint(name, best_metric, **model2_state)
def save_checkpoint(self, name, best_metric, **extra_save_dict):
super().save_checkpoint(name, best_metric, model2_state=self.model2.state_dict(), **extra_save_dict)

def load_checkpoint(self, checkpoint="best"):
save_dict = super().load_checkpoint(checkpoint)