Training Tips for multiple GPUs may be invalid! #78
I agree! I was troubleshooting why I kept getting identical gradients on different GPUs, and only later found out I didn't include the forward step in the no_sync context.
My minimal example (adapted from https://pytorch.org/tutorials/intermediate/ddp_tutorial.html) using gradient accumulation is given below:

import argparse
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP


def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic(rank, world_size, grad_accum):
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

    # create model and move it to GPU with id rank
    toymodel = ToyModel().to(rank)
    model = DDP(toymodel, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001)
    optimizer.zero_grad()

    for epoch in range(5):
        ### Main Loop ###
        if rank == 0:
            print(f"\n----Epoch {epoch} started")
        for i in range(grad_accum):
            outputs = model(torch.randn(20, 10))
            labels = torch.randn(20, 5).to(rank)
            loss = loss_fn(outputs, labels) / grad_accum
            if i == grad_accum - 1:
                loss.backward()
            else:
                with model.no_sync():
                    print(f"No sync, in epoch {epoch} iter {i}, on rank {rank}, loss={loss:5f}")
                    if epoch > 0 or i > 0:
                        print(f"On rank {rank}: Before step, {next(model.parameters()).grad.flatten()[:5].tolist()}")
                    loss.backward()
                    print(f"On rank {rank}: After step, {next(model.parameters()).grad.flatten()[:5].tolist()}")
        optimizer.step()
        optimizer.zero_grad()

    cleanup()


def run_demo(demo_fn, world_size, grad_accum):
    mp.spawn(demo_fn,
             args=(world_size, grad_accum),
             nprocs=world_size,
             join=True)


if __name__ == "__main__":
    # parse args
    parser = argparse.ArgumentParser()
    parser.add_argument("--gpus", "-g", type=int, default=2)
    parser.add_argument("--grad_accum", "-ga", type=int, default=3)
    args = parser.parse_args()
    run_demo(demo_basic, args.gpus, args.grad_accum)

A fix is given below, obtained by replacing the main loop with:

import contextlib
# ...
        ### Main Loop ###
        if rank == 0:
            print(f"\n----Epoch {epoch} started")
        for i in range(grad_accum):
            # contextlib.suppress() with no arguments behaves as a do-nothing context manager
            context = model.no_sync() if (i != grad_accum - 1) else contextlib.suppress()
            with context:
                outputs = model(torch.randn(20, 10))
                labels = torch.randn(20, 5).to(rank)
                loss = loss_fn(outputs, labels) / grad_accum
                if i == grad_accum - 1:
                    loss.backward()
                else:
                    print(f"No sync, in epoch {epoch} iter {i}, on rank {rank}, loss={loss:5f}")
                    if epoch > 0 or i > 0:
                        print(f"On rank {rank}: Before step, {next(model.parameters()).grad.flatten()[:5].tolist()}")
                    loss.backward()
                    print(f"On rank {rank}: After step, {next(model.parameters()).grad.flatten()[:5].tolist()}")
        optimizer.step()
        optimizer.zero_grad()
# ...
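For completeness, the script is launched directly, e.g. python ddp_demo.py --gpus 2 --grad_accum 3 (the filename is only a placeholder for wherever the snippet above is saved). With the original main loop, the gradients printed during the no_sync iterations should come out identical on every rank, because they are still being all-reduced; with the fixed loop they should differ between ranks, since each process draws its own random batches and the synchronization is genuinely skipped.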
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
Dear @davda54:
In the README, the author mentions a training tip for multiple GPUs. However, that code snippet is in fact invalid: the forward step should also be wrapped in the no_sync context, as shown in the code snippet below.
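A minimal sketch of the corrected pattern (assuming the SAM optimizer's first_step/second_step interface from this repository; model is the DDP-wrapped network, and loss_function, data, input, target are placeholders):

# loss_function, data, model, optimizer are assumed to be defined elsewhere (placeholders)
for input, target in data:
    # first forward-backward pass: wrap BOTH the forward and the backward in no_sync,
    # so that DDP neither prepares nor runs the gradient all-reduce for this pass
    with model.no_sync():
        loss = loss_function(model(input), target)
        loss.backward()
    optimizer.first_step(zero_grad=True)

    # second forward-backward pass: run normally so the gradients are all-reduced once
    loss_function(model(input), target).backward()
    optimizer.second_step(zero_grad=True)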
The reason for this is a bit complicated, and we need to refer to the PyTorch source code. Looking at the source of no_sync (https://github.com/pytorch/pytorch/blob/8b0cc9c752477238cacfa171abf5061bc08bed28/torch/nn/parallel/distributed.py#L945-L967), we can see that entering the no_sync context sets require_backward_grad_sync to False; otherwise it stays True by default.
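For reference, the linked implementation is essentially the following context manager (paraphrased and slightly simplified from the PyTorch source above):

from contextlib import contextmanager

# paraphrase of torch.nn.parallel.DistributedDataParallel.no_sync
@contextmanager
def no_sync(self):
    # temporarily disable gradient synchronization, restoring the old value on exit
    old_require_backward_grad_sync = self.require_backward_grad_sync
    self.require_backward_grad_sync = False
    try:
        yield
    finally:
        self.require_backward_grad_sync = old_require_backward_grad_sync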
Also, the source code of forward (https://github.com/pytorch/pytorch/blob/8b0cc9c752477238cacfa171abf5061bc08bed28/torch/nn/parallel/distributed.py#L1005-L1010) checks the value of require_backward_grad_sync and invokes prepare_for_forward if require_backward_grad_sync == True. So wrapping only the backward step in no_sync is not enough. In short, DDP already prepares for the gradient all_reduce of backpropagation during forward propagation, which is why the forward step must be wrapped in no_sync to be truly effective.
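The start of DistributedDataParallel.forward at the linked lines boils down to roughly the following (heavily simplified paraphrase; only the sync-related check is kept):

# simplified paraphrase of DistributedDataParallel.forward
def forward(self, *inputs, **kwargs):
    # while require_backward_grad_sync is still True, DDP already prepares here for
    # the gradient all-reduce that will run during the coming backward pass
    if torch.is_grad_enabled() and self.require_backward_grad_sync:
        self.reducer.prepare_for_forward()
    # ... the wrapped module is then executed as usual ...

So if the forward pass runs outside no_sync, this preparation has already happened by the time backward is called, no matter where the backward itself is placed.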
Also, we provide a simple real-world example here. It trains a ResNet-18 on CIFAR-10 with SGD+SAM under DDP:
demo.zip
By changing the code in demo.py, we obtain the following three formats: W1 is the README implementation, W2 is without no_sync, and W3 is our format. When the seed is fixed, W1 and W2 produce the same result, which shows that the README tip is invalid.