
[Training] Resume checkpoint global step inconsistent/confusion across scripts #8296

Open
vinm007 opened this issue May 28, 2024 · 2 comments
Labels
bug Something isn't working

Comments

@vinm007
vinm007 commented May 28, 2024

Describe the bug

Hi,
I have been working with the training scripts for multiple models (T2I, IP2P) and found that the logic for calculating the step and epoch when resuming training differs across scripts.
In the train_text_to_image.py script (link):

accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])
initial_global_step = global_step
first_epoch = global_step // num_update_steps_per_epoch

In the train_instruct_pix2pix.py script (link):

accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])

resume_global_step = global_step * args.gradient_accumulation_steps
first_epoch = global_step // num_update_steps_per_epoch
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
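
To make the two calculations concrete, here is a minimal worked sketch with hypothetical numbers (the checkpoint name, num_update_steps_per_epoch and gradient_accumulation_steps below are my own assumptions, not values from the scripts):

path = "checkpoint-1500"                  # hypothetical checkpoint directory
num_update_steps_per_epoch = 500          # optimizer steps per epoch (assumed)
gradient_accumulation_steps = 4           # assumed

global_step = int(path.split("-")[1])     # 1500 optimizer steps already taken

# train_text_to_image.py style: everything stays in optimizer steps
initial_global_step = global_step                          # 1500
first_epoch = global_step // num_update_steps_per_epoch    # 1500 // 500 = 3

# train_instruct_pix2pix.py style: convert to raw dataloader (micro-)batches
resume_global_step = global_step * gradient_accumulation_steps      # 6000 micro-batches consumed
resume_step = resume_global_step % (num_update_steps_per_epoch
                                    * gradient_accumulation_steps)  # 6000 % 2000 = 0

If I read it correctly, resume_step is measured in raw dataloader batches within the resumed epoch, while the text-to-image script only ever tracks optimizer steps, which seems to be where the factor of gradient_accumulation_steps comes from.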

In a similar issue, some changes were made for the progress-bar inconsistency, but I am a bit confused about the following:

  1. The multiplication by args.gradient_accumulation_steps in the train_instruct_pix2pix.py script
  2. In general, what does global_step indicate and how is it updated? In both scripts I can see the following code, but I couldn't understand it from the accelerate documentation:
if accelerator.sync_gradients:
    if args.use_ema:
        ema_unet.step(unet.parameters())
    progress_bar.update(1)
    global_step += 1
    accelerator.log({"train_loss": train_loss}, step=global_step)
    train_loss = 0.0

If we are using multiple GPUs with gradient accumulation, at what point is global_step updated? Is it incremented independently by each GPU (since the code is not wrapped in accelerator.is_main_process), and how does accumulation affect the tracking here?
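
For context, here is a minimal runnable sketch of the accelerate gradient-accumulation pattern I believe these scripts follow (the toy model, data, and gradient_accumulation_steps=4 are my own assumptions, not taken from the scripts):

import torch
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=4)
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
dataset = torch.utils.data.TensorDataset(torch.randn(64, 8), torch.randn(64, 1))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=4)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

global_step = 0
for x, y in dataloader:
    with accelerator.accumulate(model):
        loss = torch.nn.functional.mse_loss(model(x), y)
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()

    # sync_gradients is True only on the micro-batch where gradients are
    # actually synchronized and the optimizer really steps (every 4th batch
    # here), so global_step counts optimizer steps, not dataloader batches.
    if accelerator.sync_gradients:
        global_step += 1

print(global_step)  # 16 micro-batches / 4 accumulation steps = 4

As far as I can tell, every process runs this same increment, so on multiple GPUs the counter advances in lockstep on each rank rather than being summed across ranks, which would explain why it is not guarded by accelerator.is_main_process.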

Reproduction

Logs

No response

System Info

Who can help?

@sayakpaul

vinm007 added the bug (Something isn't working) label on May 28, 2024
@sayakpaul
Member

Sorry that you're facing confusion.

The multiplication by args.gradient_accumulation_steps in the train_instruct_pix2pix.py script

Why should it not be the case? The calculation is based on steps, and without the GA steps it would be incorrect, no?

Cc'ing @muellerzr for further clarification on the accelerate side.
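
Concretely, a tiny sketch of what that factor accounts for (numbers are hypothetical, not from the script): one optimizer step consumes gradient_accumulation_steps dataloader batches, so a resume loop that skips raw batches has to scale the checkpoint's step count accordingly.

gradient_accumulation_steps = 4
global_step = 1500                                                     # optimizer steps from the checkpoint name
batches_already_consumed = global_step * gradient_accumulation_steps   # 6000 dataloader batches
# Skipping only global_step (1500) batches on resume would under-skip by a factor of 4.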

@vinm007
Author

vinm007 commented May 29, 2024

Sorry that you're facing confusion.

The multiplication by args.gradient_accumulation_steps in the train_instruct_pix2pix.py script

Why should it not be the case? The calculation is based on steps, and without the GA steps it would be incorrect, no?

Cc'ing @muellerzr for further clarification on the accelerate side.

I am not sure about the calculation, but it does differ between these two scripts. Is one of them outdated or wrong?
To work out the calculation, I tried to understand how global_step is updated but couldn't. In general, it is incremented by 1 when accelerator.sync_gradients is true. The following code is used to update global_step:

if accelerator.sync_gradients:
    if args.use_ema:
        ema_unet.step(unet.parameters())
    progress_bar.update(1)
    global_step += 1
    accelerator.log({"train_loss": train_loss}, step=global_step)
    train_loss = 0.0

What does this code imply?
Is this counter updated by each GPU (in the multi-process scenario) or not? Does the sync_gradients flag take care of gradient accumulation or not? Only once I know that can I deduce the calculation.
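
Assuming the answer is "yes" on both counts (sync_gradients only flips to True at accumulation boundaries, and every process increments the counter in lockstep), the mapping from global_step back to dataloader batches would look like this (all numbers and names below are hypothetical):

gradient_accumulation_steps = 4
num_update_steps_per_epoch = 500                                   # optimizer steps per epoch
global_step = 1750                                                 # parsed from "checkpoint-1750"

first_epoch = global_step // num_update_steps_per_epoch            # 3
steps_into_epoch = global_step % num_update_steps_per_epoch        # 250 optimizer steps
batches_to_skip = steps_into_epoch * gradient_accumulation_steps   # 1000 dataloader batches

(Newer accelerate versions also expose Accelerator.skip_first_batches for skipping those batches, though I am not sure whether these scripts use it.)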
