Training Tips for multiple GPUs may be invalid! #78

Closed
crj1998 opened this issue Oct 18, 2022 · 3 comments

crj1998 commented Oct 18, 2022

Dear @davda54:
In the README, the author mentions a training tip for multiple GPUs. However, that code snippet is invalid: the forward step should also be wrapped in the no_sync context, as shown in the following snippet.

for input, output in data:
  # first forward-backward pass
  with model.no_sync():  # <- this is the important line
    loss = loss_function(output, model(input))  # move forward into `no_sync` context
    loss.backward()
  optimizer.first_step(zero_grad=True)
  
  # second forward-backward pass
  loss_function(output, model(input)).backward()
  optimizer.second_step(zero_grad=True)

The reason for this is a bit subtle, and we need to refer to the PyTorch source code.

Looking at the source code of no_sync:
https://github.com/pytorch/pytorch/blob/8b0cc9c752477238cacfa171abf5061bc08bed28/torch/nn/parallel/distributed.py#L945-L967
we can see that entering the no_sync context sets require_backward_grad_sync to False; otherwise it is True by default.
The source code of forward (https://github.com/pytorch/pytorch/blob/8b0cc9c752477238cacfa171abf5061bc08bed28/torch/nn/parallel/distributed.py#L1005-L1010) checks the value of require_backward_grad_sync and invokes prepare_for_forward when it is True. So wrapping only the backward step in no_sync is not enough.

In short, DDP already prepares for the gradient all_reduce of backpropagation during the forward pass. Therefore, the forward step must also be wrapped in no_sync for the trick to actually take effect.
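
To make the mechanism concrete, here is a simplified sketch of the control flow described above. This is not the real PyTorch implementation; the class and helper names are illustrative, and it only mirrors the two pieces that matter here: no_sync flips require_backward_grad_sync, and the forward path checks that flag before running the wrapped module.

import contextlib

class SimplifiedDDP:
    # Toy stand-in for torch.nn.parallel.DistributedDataParallel.
    # NOT the real implementation -- it only mirrors the control flow above.
    def __init__(self, module):
        self.module = module
        self.require_backward_grad_sync = True  # gradients are synced by default

    @contextlib.contextmanager
    def no_sync(self):
        old = self.require_backward_grad_sync
        self.require_backward_grad_sync = False
        try:
            yield
        finally:
            self.require_backward_grad_sync = old

    def _prepare_for_forward(self):
        # Placeholder for the reducer bookkeeping that real DDP performs here;
        # this is where the backward all-reduce gets scheduled.
        pass

    def __call__(self, *inputs, **kwargs):
        if self.require_backward_grad_sync:
            # The check happens in forward(), so calling the model outside
            # no_sync() already commits DDP to syncing gradients in backward().
            self._prepare_for_forward()
        return self.module(*inputs, **kwargs)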

We also provide a simple real-world example here: it trains ResNet-18 on CIFAR-10 with SGD+SAM under DDP.
demo.zip

CUDA_VISIBLE_DEVICES=0,1 torchrun  --nproc_per_node 2 demo.py

Changing the code in demo.py gives the following three variants: W1 is the README implementation, W2 omits no_sync entirely, and W3 is our version. With a fixed seed, W1 and W2 produce the same result, which shows that the README snippet is invalid.

W1
# first forward-backward pass
logits = model(inputs)
loss = criterion(logits, targets)
with model.no_sync():  # <- this is the important line
    loss.backward()
optimizer.first_step(zero_grad=True)

W2
# first forward-backward pass
logits = model(inputs)
loss = criterion(logits, targets)
loss.backward()
optimizer.first_step(zero_grad=True)

W3
# first forward-backward pass
with model.no_sync():  # <- this is the important line
    logits = model(inputs)
    loss = criterion(logits, targets)
    loss.backward()
optimizer.first_step(zero_grad=True)
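
One quick way to confirm which variant actually skips the synchronization is to print a slice of the gradients on each rank right after the first backward pass: with W1 and W2 the slices are identical across ranks, while with W3 they differ because each rank keeps its local gradient. A minimal check, reusing the model/criterion/inputs/targets/optimizer names from the snippets above:

import torch.distributed as dist

# first forward-backward pass (W3 variant)
with model.no_sync():
    logits = model(inputs)
    loss = criterion(logits, targets)
    loss.backward()

# with no_sync the per-rank gradients should differ; with W1/W2 they are
# identical on every rank because DDP has already all-reduced them
grad_slice = next(model.parameters()).grad.flatten()[:5]
print(f"rank {dist.get_rank()}: grad after first backward = {grad_slice.tolist()}")

optimizer.first_step(zero_grad=True)
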
crj1998 changed the title from "Training Tips for multiple GPUs may be wrong !" to "Training Tips for multiple GPUs may be invalid!" on Oct 18, 2022

cyugao commented Oct 20, 2022

I agree! I was troubleshooting why I kept getting identical gradients on different GPUs, and only later did I find out that I hadn't included the forward step in the no_sync context.


cyugao commented Oct 20, 2022

My minimal example (adapted from https://pytorch.org/tutorials/intermediate/ddp_tutorial.html) using gradient accumulation is given below

import argparse
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp

from torch.nn.parallel import DistributedDataParallel as DDP

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic(rank, world_size, grad_accum):
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

    # create model and move it to GPU with id rank
    toymodel = ToyModel().to(rank)
    model = DDP(toymodel, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001)

    optimizer.zero_grad()
    for epoch in range(5):
        ### Main Loop ###
        if rank == 0:
            print(f"\n----Epoch {epoch} started")
        for i in range(grad_accum):
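            # NOTE: the forward pass below runs *outside* no_sync, so DDP has
            # already prepared the gradient sync -- this is the problem described above.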
            outputs = model(torch.randn(20, 10))
            labels = torch.randn(20, 5).to(rank)
            loss = loss_fn(outputs, labels) / grad_accum
            if i == grad_accum - 1:
                loss.backward()
            else:
                with model.no_sync():
                    print(f"No sync, in epoch {epoch} iter {i}, on rank {rank}, loss={loss:5f}")
                    if epoch > 0 or i > 0:
                        print(f"On rank {rank}: Before step, {next(model.parameters()).grad.flatten()[:5].tolist()}")
                    loss.backward()
                    print(f"On rank {rank}:  After step, {next(model.parameters()).grad.flatten()[:5].tolist()}")

        optimizer.step()
        optimizer.zero_grad()

    cleanup()


def run_demo(demo_fn, world_size, grad_accum):
    mp.spawn(demo_fn,
             args=(world_size, grad_accum),
             nprocs=world_size,
             join=True)

if __name__ == "__main__":
    # parse args
    parser = argparse.ArgumentParser()
    parser.add_argument("--gpus", "-g", type=int, default=2)
    parser.add_argument("--grad_accum", "-ga", type=int, default=3)
    args = parser.parse_args()
    run_demo(demo_basic, args.gpus, args.grad_accum)

A fix is given below by replacing the MAIN LOOP with

import contextlib
# ...
        ### Main Loop ###
        if rank == 0:
            print(f"\n----Epoch {epoch} started")
        for i in range(grad_accum):
            context = model.no_sync() if (i != grad_accum - 1) else contextlib.nullcontext()
            with context:
                outputs = model(torch.randn(20, 10))
                labels = torch.randn(20, 5).to(rank)
                loss = loss_fn(outputs, labels) / grad_accum
                if i == grad_accum - 1:
                    loss.backward()
                else:
                    print(f"No sync, in epoch {epoch} iter {i}, on rank {rank}, loss={loss:5f}")
                    if epoch > 0 or i > 0:
                        print(f"On rank {rank}: Before step, {next(model.parameters()).grad.flatten()[:5].tolist()}")
                    loss.backward()
                    print(f"On rank {rank}:  After step, {next(model.parameters()).grad.flatten()[:5].tolist()}")
        optimizer.step()
        optimizer.zero_grad()
# ...
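
The same conditional-context idiom carries over to SAM's two forward-backward passes: only the backward that should trigger DDP's all-reduce runs outside no_sync. A rough sketch (the sam_step helper and its arguments are illustrative, not part of this repository; it assumes a SAM optimizer with first_step/second_step as in the README):

import contextlib

def sam_step(model, optimizer, criterion, inputs, targets, sync_first_pass=False):
    # One SAM update under DDP. With sync_first_pass=False this reproduces the
    # corrected (W3) behaviour: the first pass stays local to each rank and only
    # the second backward triggers the gradient all-reduce.
    first_ctx = contextlib.nullcontext() if sync_first_pass else model.no_sync()

    # first forward-backward pass: local gradient used to perturb the weights
    with first_ctx:
        criterion(model(inputs), targets).backward()
    optimizer.first_step(zero_grad=True)

    # second forward-backward pass: gradients at the perturbed weights are
    # all-reduced across ranks as usual, then the real update is applied
    criterion(model(inputs), targets).backward()
    optimizer.second_step(zero_grad=True)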


stale bot commented Nov 12, 2022

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

stale bot added the stale label on Nov 12, 2022
stale bot closed this as completed on Nov 22, 2022