Skip to content

Commit

Permalink
FG-2223 fix clock log order
Browse files Browse the repository at this point in the history
  • Loading branch information
limiteinductive committed May 21, 2024
1 parent cc7b62f commit dd8c369
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
7 changes: 5 additions & 2 deletions src/refiners/training_utils/clock.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,16 +116,19 @@ def on_epoch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:

def on_step_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
if self.num_minibatches_processed == 0:
if self.iteration > 0:
self.log(f"Iteration {self.iteration - 1} ended.")
self.log(f"Iteration {self.iteration} started.")
self.log(f"Step {self.step} started.")

def on_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
self.log(f"Step {self.step} ended.")
self.step += 1
self.num_batches_processed += 1

def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
self.num_minibatches_processed += 1
self.num_batches_processed += 1

def on_optimizer_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
self.log(f"Iteration {self.iteration} ended.")
self.iteration += 1
self.num_minibatches_processed = 0
4 changes: 2 additions & 2 deletions tests/training_utils/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,11 +244,11 @@ def test_callback_registration(mock_trainer: MockTrainer) -> None:
mock_trainer.train()

# Check that the callback skips every other iteration
assert mock_trainer.mock_callback.optimizer_step_count == mock_trainer.clock.iteration // 2
assert mock_trainer.mock_callback.optimizer_step_count == mock_trainer.clock.iteration // 2 + 1
assert mock_trainer.mock_callback.step_end_count == mock_trainer.clock.step // 3 + 1

# Check that the random seed was set
assert mock_trainer.mock_callback.optimizer_step_random_int == 93
assert mock_trainer.mock_callback.optimizer_step_random_int == 41
assert mock_trainer.mock_callback.step_end_random_int == 81


Expand Down

0 comments on commit dd8c369

Please sign in to comment.