From de158a1e40a5ccde8d246dd65c2cd098eec72216 Mon Sep 17 00:00:00 2001
From: anwai98
Date: Thu, 29 Feb 2024 21:18:30 +0100
Subject: [PATCH 1/2] Add current_metric while saving checkpoints

---
 torch_em/trainer/default_trainer.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/torch_em/trainer/default_trainer.py b/torch_em/trainer/default_trainer.py
index 36ded336..796dfcd2 100644
--- a/torch_em/trainer/default_trainer.py
+++ b/torch_em/trainer/default_trainer.py
@@ -458,7 +458,7 @@ def _initialize(self, iterations, load_from_checkpoint, epochs=None):
             best_metric = np.inf
         return best_metric
 
-    def save_checkpoint(self, name, best_metric, train_time=0.0, **extra_save_dict):
+    def save_checkpoint(self, name, current_metric, best_metric, train_time=0.0, **extra_save_dict):
         save_path = os.path.join(self.checkpoint_folder, f"{name}.pt")
         extra_init_dict = extra_save_dict.pop("init", {})
         save_dict = {
@@ -466,6 +466,7 @@ def save_checkpoint(self, name, best_metric, train_time=0.0, **extra_save_dict):
             "epoch": self._epoch,
             "best_epoch": self._best_epoch,
             "best_metric": best_metric,
+            "current_metric": current_metric,
             "model_state": self.model.state_dict(),
             "optimizer_state": self.optimizer.state_dict(),
             "init": self.init_data | extra_init_dict,
@@ -494,6 +495,7 @@ def load_checkpoint(self, checkpoint="best"):
         self._epoch = save_dict["epoch"]
         self._best_epoch = save_dict["best_epoch"]
         self.best_metric = save_dict["best_metric"]
+        self.current_metric = save_dict["current_metric"]
         self.train_time = save_dict.get("train_time", 0.0)
 
         model_state = save_dict["model_state"]
@@ -573,14 +575,16 @@ def fit(self, iterations=None, load_from_checkpoint=None, epochs=None, save_ever
             if current_metric < best_metric:
                 best_metric = current_metric
                 self._best_epoch = self._epoch
-                self.save_checkpoint("best", best_metric, train_time=total_train_time)
+                self.save_checkpoint("best", current_metric, best_metric, train_time=total_train_time)
 
             # save this checkpoint as the latest checkpoint
-            self.save_checkpoint("latest", best_metric, train_time=total_train_time)
+            self.save_checkpoint("latest", current_metric, best_metric, train_time=total_train_time)
 
             # if we save after every k-th epoch then check if we need to save now
             if save_every_kth_epoch is not None and (self._epoch + 1) % save_every_kth_epoch == 0:
-                self.save_checkpoint(f"epoch-{self._epoch + 1}", best_metric, train_time=total_train_time)
+                self.save_checkpoint(
+                    f"epoch-{self._epoch + 1}", current_metric, best_metric, train_time=total_train_time
+                )
 
             # if early stopping has been specified then check if the stopping condition is met
             if self.early_stopping is not None:

From 34b0a0a5600eef348f49889fcc6689d0388b3733 Mon Sep 17 00:00:00 2001
From: anwai98
Date: Thu, 29 Feb 2024 21:28:27 +0100
Subject: [PATCH 2/2] Update other trainers

---
 torch_em/self_training/fix_match.py    | 4 ++--
 torch_em/self_training/mean_teacher.py | 4 ++--
 torch_em/trainer/spoco_trainer.py      | 6 ++++--
 3 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/torch_em/self_training/fix_match.py b/torch_em/self_training/fix_match.py
index 7520f90c..c92a7176 100644
--- a/torch_em/self_training/fix_match.py
+++ b/torch_em/self_training/fix_match.py
@@ -136,7 +136,7 @@ def __init__(
     # functionality for saving checkpoints and initialization
     #
 
-    def save_checkpoint(self, name, best_metric, **extra_save_dict):
+    def save_checkpoint(self, name, current_metric, best_metric, **extra_save_dict):
         train_loader_kwargs = get_constructor_arguments(self.train_loader)
         val_loader_kwargs = get_constructor_arguments(self.val_loader)
         extra_state = {
@@ -152,7 +152,7 @@ def save_checkpoint(self, name, best_metric, **extra_save_dict):
             },
         }
         extra_state.update(**extra_save_dict)
-        super().save_checkpoint(name, best_metric, **extra_state)
+        super().save_checkpoint(name, current_metric, best_metric, **extra_state)
 
     # distribution alignment - encourages the distribution of the model's generated pseudo labels to match the marginal
     # distribution of pseudo labels from the source transfer

diff --git a/torch_em/self_training/mean_teacher.py b/torch_em/self_training/mean_teacher.py
index 2e5c39b0..3820ae2f 100644
--- a/torch_em/self_training/mean_teacher.py
+++ b/torch_em/self_training/mean_teacher.py
@@ -171,7 +171,7 @@ def _momentum_update(self):
     # functionality for saving checkpoints and initialization
    #
 
-    def save_checkpoint(self, name, best_metric, **extra_save_dict):
+    def save_checkpoint(self, name, current_metric, best_metric, **extra_save_dict):
         train_loader_kwargs = get_constructor_arguments(self.train_loader)
         val_loader_kwargs = get_constructor_arguments(self.val_loader)
         extra_state = {
@@ -188,7 +188,7 @@ def save_checkpoint(self, name, best_metric, **extra_save_dict):
             },
         }
         extra_state.update(**extra_save_dict)
-        super().save_checkpoint(name, best_metric, **extra_state)
+        super().save_checkpoint(name, current_metric, best_metric, **extra_state)
 
     def load_checkpoint(self, checkpoint="best"):
         save_dict = super().load_checkpoint(checkpoint)

diff --git a/torch_em/trainer/spoco_trainer.py b/torch_em/trainer/spoco_trainer.py
index ef2c9568..68c125f1 100644
--- a/torch_em/trainer/spoco_trainer.py
+++ b/torch_em/trainer/spoco_trainer.py
@@ -32,8 +32,10 @@ def _momentum_update(self):
         for param_model, param_teacher in zip(self.model.parameters(), self.model2.parameters()):
             param_teacher.data = param_teacher.data * self.momentum + param_model.data * (1. - self.momentum)
 
-    def save_checkpoint(self, name, best_metric, **extra_save_dict):
-        super().save_checkpoint(name, best_metric, model2_state=self.model2.state_dict(), **extra_save_dict)
+    def save_checkpoint(self, name, current_metric, best_metric, **extra_save_dict):
+        super().save_checkpoint(
+            name, current_metric, best_metric, model2_state=self.model2.state_dict(), **extra_save_dict
+        )
 
     def load_checkpoint(self, checkpoint="best"):
         save_dict = super().load_checkpoint(checkpoint)
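Note (not part of the patch): with this change every saved checkpoint records both the validation metric of the epoch that wrote it ("current_metric") and the running best ("best_metric"), which makes the "latest" and "epoch-k" checkpoints informative on their own. The sketch below shows how these entries could be read back from a checkpoint file; the checkpoint path and map_location are illustrative assumptions, not values taken from the patch.

    # Minimal sketch, assuming a trainer has written checkpoints into its
    # checkpoint_folder as "<name>.pt" (as done by save_checkpoint in the patch).
    import torch

    checkpoint_path = "./checkpoints/my-model/latest.pt"  # hypothetical location
    save_dict = torch.load(checkpoint_path, map_location="cpu")

    # "current_metric" is the metric of the epoch that produced this checkpoint,
    # "best_metric" is the best validation metric seen up to that point.
    print("epoch:", save_dict["epoch"])
    print("current metric:", save_dict["current_metric"])
    print("best metric:", save_dict["best_metric"])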