Skip to content

Commit

Permalink
Enable passing external pbar to trainer (#232)
Browse files: browse the repository at this point in the history
Enable passing external pbar to trainer
  • Loading branch information
constantinpape committed Apr 12, 2024
1 parent 3a02a99 commit da1f5a9
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 6 deletions.
1 change: 1 addition & 0 deletions environment_cpu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ name:
torch-em-cpu
dependencies:
- affogato
- bioimageio.spec <0.5.0
- bioimageio.core >=0.5.0
- cpuonly
- imagecodecs
Expand Down
1 change: 1 addition & 0 deletions environment_gpu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ name:
torch-em
dependencies:
- affogato
- bioimageio.spec <0.5.0
- bioimageio.core >=0.5.0
- imagecodecs
- python-elf
Expand Down
23 changes: 17 additions & 6 deletions torch_em/trainer/default_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,14 @@ def load_checkpoint(self, checkpoint="best"):

return save_dict

def fit(self, iterations=None, load_from_checkpoint=None, epochs=None, save_every_kth_epoch=None):
def fit(
self,
iterations=None,
load_from_checkpoint=None,
epochs=None,
save_every_kth_epoch=None,
progress=None,
):
"""Run neural network training.
Exactly one of 'iterations' or 'epochs' has to be passed.
Expand All @@ -527,6 +534,8 @@ def fit(self, iterations=None, load_from_checkpoint=None, epochs=None, save_ever
epochs [int] - how long to train, specified in epochs (default: None)
save_every_kth_epoch [int] - save checkpoints after every kth epoch separately.
The corresponding checkpoints will be saved with the naming scheme 'epoch-{epoch}.pt'. (default: None)
progress [progress_bar] - optional progress bar for integration with external tools.
Expected to follow the tqdm interface.
"""
best_metric = self._initialize(iterations, load_from_checkpoint, epochs)
print(
Expand All @@ -547,12 +556,14 @@ def fit(self, iterations=None, load_from_checkpoint=None, epochs=None, save_ever
validate = self._validate
print("Training with single precision")

progress = tqdm(
total=epochs * len(self.train_loader) if iterations is None else iterations,
desc=f"Epoch {self._epoch}", leave=True
)
msg = "Epoch %i: average [s/it]: %f, current metric: %f, best metric: %f"
total_iterations = epochs * len(self.train_loader) if iterations is None else iterations
if progress is None:
progress = tqdm(total=total_iterations, desc=f"Epoch {self._epoch}", leave=True)
else:
progress.total = total_iterations
progress.set_description(f"Epoch {self._epoch}")

msg = "Epoch %i: average [s/it]: %f, current metric: %f, best metric: %f"
train_epochs = self.max_epoch - self._epoch
t_start = time.time()
for _ in range(train_epochs):
Expand Down

0 comments on commit da1f5a9

Please sign in to comment.