
Fix correct gradient accumulation #407

Merged
vwxyzjn merged 5 commits into main from fix-grad-acc-bug on Jun 14, 2023

Conversation

@younesbelkada
Contributor

What does this PR do?

This PR fixes gradient accumulation, which was first introduced in #220.

Fixes #321

cc @lvwerra
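For context, the standard gradient-accumulation pattern scales each mini-batch loss by the number of accumulation steps and runs the optimizer once per accumulation window. The sketch below is illustrative only, with hypothetical toy data; the actual PR changes live in the diff (in TRL, the backward pass was routed through accelerate's `accumulate` context manager rather than handled manually):

```python
import torch

# Illustrative sketch of the manual gradient-accumulation pattern
# (NOT the PR's actual code; toy data and hyperparameters are made up).
torch.manual_seed(0)
model = torch.nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = torch.nn.MSELoss()
accumulation_steps = 4

# eight toy mini-batches of two samples each
data = [(torch.randn(2, 1), torch.randn(2, 1)) for _ in range(8)]

steps_taken = 0
for i, (inputs, labels) in enumerate(data):
    # scale so the accumulated gradient matches the full-batch mean gradient
    loss = criterion(model(inputs), labels) / accumulation_steps
    loss.backward()  # gradients accumulate into .grad across iterations
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()       # one optimizer step per accumulation window
        optimizer.zero_grad()  # reset gradients only after stepping
        steps_taken += 1
```

Dividing the loss by `accumulation_steps` is the key detail: without it, the accumulated gradient is the sum, not the mean, over the window.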

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Jun 6, 2023

The documentation is not available anymore as the PR was closed or merged.

Contributor

@vwxyzjn vwxyzjn left a comment


The PR makes sense! Would it be possible to verify that it in fact reproduces the exact same gradient with different numbers of gradient accumulation steps?

Here is a snippet:

import torch
from torch.utils.data import TensorDataset, DataLoader
import copy

# seed
torch.manual_seed(0)

# define toy inputs and labels
x = torch.tensor([1., 2., 3., 4.])
y = torch.tensor([2., 4., 6., 8.])

# define dataset and dataloader
dataset = TensorDataset(x, y)
dataloader = DataLoader(dataset, batch_size=2)

# define model, optimizer and loss function
model = torch.nn.Linear(1, 1)

# clone the model
model_clone = copy.deepcopy(model)
criterion = torch.nn.MSELoss()
accumulation_steps = 2

# loop over batches
for i, (inputs, labels) in enumerate(dataloader):
    # reshape inputs and labels
    inputs = inputs.view(-1, 1)
    labels = labels.view(-1, 1)
    # forward pass
    outputs = model(inputs)
    loss = criterion(outputs, labels) / accumulation_steps
    # backward pass
    loss.backward()
    # check if accumulation is done
    if (i + 1) % accumulation_steps == 0:
        print("w/ accumulation, the final model grad is", model.weight.grad) 
        break

loss = criterion(model_clone(x.view(-1, 1)), y.view(-1, 1))
loss.backward()
print("w/o accumulation, the final model grad is", model_clone.weight.grad)
Output:

w/ accumulation, the final model grad is tensor([[-27.4301]])
w/o accumulation, the final model grad is tensor([[-27.4301]])
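The printed equality can also be checked programmatically. The following is a minimal variant of the snippet above (same setup, same seed) that asserts the accumulated gradient matches the full-batch gradient with `torch.allclose` instead of eyeballing the printout:

```python
import copy
import torch
from torch.utils.data import TensorDataset, DataLoader

torch.manual_seed(0)
x = torch.tensor([1., 2., 3., 4.])
y = torch.tensor([2., 4., 6., 8.])
dataloader = DataLoader(TensorDataset(x, y), batch_size=2)

model = torch.nn.Linear(1, 1)
model_clone = copy.deepcopy(model)
criterion = torch.nn.MSELoss()
accumulation_steps = 2

# accumulate gradients over two mini-batches, scaling each loss
for inputs, labels in dataloader:
    loss = criterion(model(inputs.view(-1, 1)), labels.view(-1, 1)) / accumulation_steps
    loss.backward()

# single full-batch backward pass on the untouched clone
criterion(model_clone(x.view(-1, 1)), y.view(-1, 1)).backward()

assert torch.allclose(model.weight.grad, model_clone.weight.grad)
print("gradients match")
```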

Contributor Author

@younesbelkada younesbelkada left a comment


Yes, definitely! Will add a test for that, thanks a lot for the snippet and the pointer!
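Such a test might look like the following hypothetical sketch (function and variable names are illustrative, not the actual test added in the PR): it checks that the accumulated gradient matches the single full-batch gradient for several accumulation-step counts.

```python
import copy
import torch

def grad_with_accumulation(model, x, y, accumulation_steps):
    """Return the weight gradient after accumulating over equal-sized chunks.
    Hypothetical helper for illustration; deep-copies the model so the
    reference model is left untouched."""
    model = copy.deepcopy(model)
    criterion = torch.nn.MSELoss()
    for xb, yb in zip(x.chunk(accumulation_steps), y.chunk(accumulation_steps)):
        (criterion(model(xb), yb) / accumulation_steps).backward()
    return model.weight.grad.clone()

torch.manual_seed(0)
base = torch.nn.Linear(1, 1)
x = torch.randn(8, 1)
y = torch.randn(8, 1)

# accumulation_steps=1 is the plain full-batch gradient
reference = grad_with_accumulation(base, x, y, 1)
for steps in (2, 4, 8):
    assert torch.allclose(reference, grad_with_accumulation(base, x, y, steps))
print("all accumulation step counts match")
```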

@younesbelkada younesbelkada requested a review from vwxyzjn June 14, 2023 11:15
Contributor

@vwxyzjn vwxyzjn left a comment


LGTM! Thanks so much @younesbelkada

@vwxyzjn vwxyzjn merged commit 61af5f2 into main Jun 14, 2023
@vwxyzjn vwxyzjn deleted the fix-grad-acc-bug branch June 14, 2023 12:43
yxliu-TAMU pushed a commit to mincheolseong/ECEN743-GRPO-Project-Proposal that referenced this pull request Apr 20, 2025
* add correct grad acc

* add some tests but they fail

* test should pass

* style

* fix


Development

Successfully merging this pull request may close these issues.

Why is the backward step in ppo_trainer not handled by accelerate's accumulate?