Fix: Fix the problem of inconsistent training times across processes and adjust the distributed training synchronization lock.
chairc committed Sep 11, 2023
1 parent 5c4b49a commit c0157d9
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions tools/train.py
@@ -249,10 +249,6 @@ def train(rank=None, args=None):
tb_logger.add_scalar(tag=f"[{device}]: MSE", scalar_value=loss.item(),
global_step=epoch * len_dataloader + i)

-# Synchronization during distributed training
-if distributed:
-    dist.barrier()

# Saving and validating models in the main process
if save_models:
# Saving model
@@ -295,6 +291,12 @@ def train(rank=None, args=None):
logger.info(msg=f"Save the {save_name}.pt, ema_{save_name}.pt, and optim_{save_name}.pt.")
logger.info(msg="Save the model.")
logger.info(msg=f"[{device}]: Finish epoch {epoch}:")

+# Synchronization during distributed training
+if distributed:
+    logger.info(msg=f"[{device}]: Synchronization during distributed training.")
+    dist.barrier()

logger.info(msg=f"[{device}]: Finish training.")

# Clean up the distributed environment
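For context, below is a minimal, self-contained sketch of the barrier placement this commit adopts. It is not the repository's actual tools/train.py: it assumes a launch via torchrun (which sets WORLD_SIZE), and the checkpoint-saving step is only indicated by a comment. Placing dist.barrier() at the end of the epoch, after the main process has saved, keeps the per-epoch timing of all ranks aligned.

import os
import torch
import torch.distributed as dist

def main():
    # Detect a distributed launch (e.g. torchrun sets WORLD_SIZE).
    distributed = int(os.environ.get("WORLD_SIZE", "1")) > 1
    if distributed:
        dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo")
    rank = dist.get_rank() if distributed else 0

    for epoch in range(3):
        # ... the training steps for this epoch run here on every rank ...

        # Only the main process saves checkpoints; other ranks skip ahead.
        if rank == 0:
            pass  # e.g. torch.save(model.state_dict(), f"ckpt_{epoch}.pt")

        # Barrier at the END of the epoch: every rank waits until rank 0 has
        # finished saving, so per-epoch wall-clock time stays consistent
        # across processes instead of drifting apart.
        if distributed:
            dist.barrier()

    if distributed:
        dist.destroy_process_group()

if __name__ == "__main__":
    main()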
