Add current_metric while saving checkpoints #223

Merged: 2 commits, Mar 4, 2024
torch_em/self_training/fix_match.py: 2 additions & 2 deletions
@@ -136,7 +136,7 @@ def __init__(
     # functionality for saving checkpoints and initialization
     #

-    def save_checkpoint(self, name, best_metric, **extra_save_dict):
+    def save_checkpoint(self, name, current_metric, best_metric, **extra_save_dict):
         train_loader_kwargs = get_constructor_arguments(self.train_loader)
         val_loader_kwargs = get_constructor_arguments(self.val_loader)
         extra_state = {
@@ -152,7 +152,7 @@ def save_checkpoint(self, name, best_metric, **extra_save_dict):
             },
         }
         extra_state.update(**extra_save_dict)
-        super().save_checkpoint(name, best_metric, **extra_state)
+        super().save_checkpoint(name, current_metric, best_metric, **extra_state)

     # distribution alignment - encourages the distribution of the model's generated pseudo labels to match the marginal
     # distribution of pseudo labels from the source transfer
torch_em/self_training/mean_teacher.py: 2 additions & 2 deletions
@@ -171,7 +171,7 @@ def _momentum_update(self):
     # functionality for saving checkpoints and initialization
     #

-    def save_checkpoint(self, name, best_metric, **extra_save_dict):
+    def save_checkpoint(self, name, current_metric, best_metric, **extra_save_dict):
         train_loader_kwargs = get_constructor_arguments(self.train_loader)
         val_loader_kwargs = get_constructor_arguments(self.val_loader)
         extra_state = {
@@ -188,7 +188,7 @@ def save_checkpoint(self, name, best_metric, **extra_save_dict):
             },
         }
         extra_state.update(**extra_save_dict)
-        super().save_checkpoint(name, best_metric, **extra_state)
+        super().save_checkpoint(name, current_metric, best_metric, **extra_state)

     def load_checkpoint(self, checkpoint="best"):
         save_dict = super().load_checkpoint(checkpoint)
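
Both self-training trainers override save_checkpoint only to bundle their loader and pseudo-labeling state before delegating to the base class, so each override has to mirror the new signature and pass the current metric through unchanged. The following is a minimal, self-contained sketch of that forwarding pattern; the class names and the contents of extra_state are placeholders for illustration, not torch_em identifiers.

```python
# Minimal sketch of the override pattern in fix_match.py and mean_teacher.py:
# the subclass adds its own state and forwards both metrics to the parent class.
# Class names and extra_state contents are illustrative placeholders.

class BaseTrainer:
    def save_checkpoint(self, name, current_metric, best_metric, **extra_save_dict):
        # stand-in for the base implementation, which writes the actual checkpoint file
        print(f"{name}: current={current_metric:.3f}, best={best_metric:.3f}, extra={sorted(extra_save_dict)}")


class SelfTrainingTrainer(BaseTrainer):
    def save_checkpoint(self, name, current_metric, best_metric, **extra_save_dict):
        extra_state = {"pseudo_labeler_kwargs": {"confidence_threshold": 0.9}}
        extra_state.update(**extra_save_dict)
        # both metrics are passed through unchanged
        super().save_checkpoint(name, current_metric, best_metric, **extra_state)


if __name__ == "__main__":
    SelfTrainingTrainer().save_checkpoint("latest", 0.42, 0.37)
```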
torch_em/trainer/default_trainer.py: 8 additions & 4 deletions
@@ -458,14 +458,15 @@ def _initialize(self, iterations, load_from_checkpoint, epochs=None):
             best_metric = np.inf
         return best_metric

-    def save_checkpoint(self, name, best_metric, train_time=0.0, **extra_save_dict):
+    def save_checkpoint(self, name, current_metric, best_metric, train_time=0.0, **extra_save_dict):
         save_path = os.path.join(self.checkpoint_folder, f"{name}.pt")
         extra_init_dict = extra_save_dict.pop("init", {})
         save_dict = {
             "iteration": self._iteration,
             "epoch": self._epoch,
             "best_epoch": self._best_epoch,
             "best_metric": best_metric,
+            "current_metric": current_metric,
             "model_state": self.model.state_dict(),
             "optimizer_state": self.optimizer.state_dict(),
             "init": self.init_data | extra_init_dict,
@@ -494,6 +495,7 @@ def load_checkpoint(self, checkpoint="best"):
         self._epoch = save_dict["epoch"]
         self._best_epoch = save_dict["best_epoch"]
         self.best_metric = save_dict["best_metric"]
+        self.current_metric = save_dict["current_metric"]
         self.train_time = save_dict.get("train_time", 0.0)

         model_state = save_dict["model_state"]
@@ -573,14 +575,16 @@ def fit(self, iterations=None, load_from_checkpoint=None, epochs=None, save_ever
             if current_metric < best_metric:
                 best_metric = current_metric
                 self._best_epoch = self._epoch
-                self.save_checkpoint("best", best_metric, train_time=total_train_time)
+                self.save_checkpoint("best", current_metric, best_metric, train_time=total_train_time)

             # save this checkpoint as the latest checkpoint
-            self.save_checkpoint("latest", best_metric, train_time=total_train_time)
+            self.save_checkpoint("latest", current_metric, best_metric, train_time=total_train_time)

             # if we save after every k-th epoch then check if we need to save now
             if save_every_kth_epoch is not None and (self._epoch + 1) % save_every_kth_epoch == 0:
-                self.save_checkpoint(f"epoch-{self._epoch + 1}", best_metric, train_time=total_train_time)
+                self.save_checkpoint(
+                    f"epoch-{self._epoch + 1}", current_metric, best_metric, train_time=total_train_time
+                )

             # if early stopping has been specified then check if the stopping condition is met
             if self.early_stopping is not None:
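
In fit(), the current metric now accompanies the running best metric in every checkpoint: "best" is written only when the metric improves, "latest" is written every epoch, and the optional epoch-k checkpoints follow the same call. The snippet below is a rough, standalone sketch of that end-of-epoch bookkeeping with a stubbed save function; it is not the trainer code itself.

```python
# Standalone sketch of the end-of-epoch checkpointing logic in fit():
# compare the current validation metric against the running best, then save
# "best" (only on improvement), "latest" (always), and optional k-th epoch
# checkpoints, recording both metrics each time. The save function is a stub.

def save_checkpoint(name, current_metric, best_metric):
    print(f"save {name}: current_metric={current_metric:.3f}, best_metric={best_metric:.3f}")

def end_of_epoch(epoch, current_metric, best_metric, save_every_kth_epoch=None):
    if current_metric < best_metric:  # lower is better, as in DefaultTrainer
        best_metric = current_metric
        save_checkpoint("best", current_metric, best_metric)

    save_checkpoint("latest", current_metric, best_metric)

    if save_every_kth_epoch is not None and (epoch + 1) % save_every_kth_epoch == 0:
        save_checkpoint(f"epoch-{epoch + 1}", current_metric, best_metric)

    return best_metric

best = float("inf")
for epoch, metric in enumerate([0.50, 0.42, 0.45]):  # toy validation metrics
    best = end_of_epoch(epoch, metric, best, save_every_kth_epoch=2)
```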
torch_em/trainer/spoco_trainer.py: 4 additions & 2 deletions
@@ -32,8 +32,10 @@ def _momentum_update(self):
         for param_model, param_teacher in zip(self.model.parameters(), self.model2.parameters()):
             param_teacher.data = param_teacher.data * self.momentum + param_model.data * (1. - self.momentum)

-    def save_checkpoint(self, name, best_metric, **extra_save_dict):
-        super().save_checkpoint(name, best_metric, model2_state=self.model2.state_dict(), **extra_save_dict)
+    def save_checkpoint(self, name, current_metric, best_metric, **extra_save_dict):
+        super().save_checkpoint(
+            name, current_metric, best_metric, model2_state=self.model2.state_dict(), **extra_save_dict
+        )

     def load_checkpoint(self, checkpoint="best"):
         save_dict = super().load_checkpoint(checkpoint)
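
The user-visible effect of the PR is that every checkpoint written by DefaultTrainer and its subclasses now stores a current_metric entry next to best_metric. Below is a hedged sketch of reading it back; the checkpoint path is an assumption, and checkpoints saved before this change will not contain the key, hence the fallback.

```python
# Sketch: inspect the metrics stored in a torch_em checkpoint.
# The path is an assumption; point it at your own checkpoint folder.
import torch

checkpoint_path = "checkpoints/my-model/best.pt"  # hypothetical location
# weights_only=False because the save dict also holds non-tensor entries (e.g. "init")
save_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

print("epoch:", save_dict["epoch"])
print("best_metric:", save_dict["best_metric"])
# only present for checkpoints written after this change
print("current_metric:", save_dict.get("current_metric", "not stored (older checkpoint)"))
```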